diff --git a/onnxruntime/test/testdata/bart_tiny.onnx b/onnxruntime/test/testdata/bart_tiny.onnx new file mode 100644 index 0000000000000..ba50963637c61 Binary files /dev/null and b/onnxruntime/test/testdata/bart_tiny.onnx differ diff --git a/orttraining/orttraining/core/graph/optimizer/adam_optimizer_builder.cc b/orttraining/orttraining/core/graph/optimizer/adam_optimizer_builder.cc index 5f6345eaa22ed..abe0d8efeb53a 100644 --- a/orttraining/orttraining/core/graph/optimizer/adam_optimizer_builder.cc +++ b/orttraining/orttraining/core/graph/optimizer/adam_optimizer_builder.cc @@ -28,7 +28,6 @@ Status AdamOptimizerBuilder::Build( const std::string& gradient_name = gradient_argdefs[i].name; const TypeProto* const weight_type_proto = weight_argdefs[i].type_proto; const TypeProto* const gradient_type_proto = gradient_argdefs[i].type_proto; - weight_to_opt_mapping[weight_name] = {}; // Return either the input gradient/weight/mixed-precision-weight or updated gradient/weight/mixed-precision-weight. ArgDef output_gradient_argdef = gradient_argdefs[i]; @@ -38,6 +37,7 @@ Status AdamOptimizerBuilder::Build( // In distributed training, some weights may not be updated by all ranks. if (opt_configs[i].enabled) { + weight_to_opt_mapping[weight_name] = {}; // The type proto initializer for Update Count const std::string update_count_string = ADAM_UC_PREFIX + "_" + weight_name; // per weight optimizer requires a per weight update count TensorProto uc_tensor_proto; @@ -82,7 +82,6 @@ Status AdamOptimizerBuilder::Build( const auto element_type = opt_configs[i].use_mixed_precision_moments ? ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16 : ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT; - // Add first- and second-order momentums to input list. for (const auto& moments_prefix : MOMENTS_PREFIXES) { const std::string gradient_moment_name = moments_prefix + "_" + weight_name; diff --git a/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc b/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc index bf05b23691505..e7565dfc7a05d 100644 --- a/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc +++ b/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc @@ -250,6 +250,7 @@ static Status AddParameterPartition( //Partition the FP32 weight weight_views = AddPartitionsForParameter(graph, graph_defs, weight_argdef.name, view_shapes, updated_weight_names_map); ORT_ENFORCE(weight_views.size() == enabled.size()); + weight_partition_info[weight_argdef.name].weight_partitioned = true; // Add View for mixed precision weight. 
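For context on the bookkeeping above: an entry is added to weight_to_opt_mapping only when the optimizer for that weight is enabled on the current rank, and weight_partition_info records that the weight (or just its fp32 optimizer state under ZeRO) has been partitioned. A rough Python illustration of one such entry, using the PartitionInfo fields introduced later in this diff (the weight name, shape, and partition name are hypothetical):

# hypothetical weight_partition_info entry for a ZeRO-partitioned fp32 weight on rank 0
partition_info_entry = {
    "original_dim": [1024, 1024],                       # full shape before partitioning
    "megatron_row_partition": -1,                       # -1: not Megatron-partitioned
    "partition_name": "encoder.layer0.weight_view_0",   # name used to look up the shard (illustrative)
    "weight_partitioned": True,                         # weight/optimizer state is sharded on this rank
}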
ArgDef mixed_precision_weight_argdef(opt_config.mixed_precision_weight_arg->Name(), opt_config.mixed_precision_weight_arg->TypeAsProto()); @@ -275,7 +276,7 @@ static Status AddParameterPartition( // Partition initial optimizer state if (enabled[i]) { - weight_partition_info[weight_argdef.name].view_name = weight_views[i].name; + weight_partition_info[weight_argdef.name].partition_name = weight_views[i].name; if (!initial_states.empty()) { ORT_ENFORCE(view_shapes.size() == 3, "Invalid view_shapes vector passed for partitioning."); diff --git a/orttraining/orttraining/core/optimizer/megatron_transformer.cc b/orttraining/orttraining/core/optimizer/megatron_transformer.cc index 340f510e5dff6..50a08c39ee625 100644 --- a/orttraining/orttraining/core/optimizer/megatron_transformer.cc +++ b/orttraining/orttraining/core/optimizer/megatron_transformer.cc @@ -3,6 +3,7 @@ #include "core/optimizer/initializer.h" #include "orttraining/core/framework/distributed_run_context.h" +#include "orttraining/core/graph/optimizer_builder.h" #include "orttraining/core/optimizer/megatron_transformer.h" #include "core/graph/graph_utils.h" #include "core/optimizer/utils.h" @@ -113,12 +114,32 @@ static uint32_t HashName(const std::string& name) { return hash; } +template +void MegatronTransformer::PartitionBufferByColumn(const T* input, + const int64_t row_count, + const int64_t column_count, + const int64_t column_stride, + const int stride, + std::vector& result) const { + const int64_t column_stride_partition = column_stride / horizontal_parallel_size_; + + const int64_t stride_partition_column_offset = horizontal_parallel_rank_ * column_stride_partition; + for (auto row_index = 0; row_index < row_count; row_index++) { + const auto row_offset = row_index * column_count; + for (auto stride_index = 0; stride_index < stride; stride_index++) { + const auto column_offset = row_offset + stride_index * column_stride + stride_partition_column_offset; + std::copy(input + column_offset, input + column_offset + column_stride_partition, std::back_inserter(result)); + } + } +} + bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const NodeArg& input_arg, ONNX_NAMESPACE::TensorProto& initializer_partition, int stride) const { + const std::string original_name = input_arg.Name(); const ONNX_NAMESPACE::TensorProto* tensor_proto; - if (!graph.GetInitializedTensor(input_arg.Name(), tensor_proto)) { - LOGS_DEFAULT(WARNING) << "PartitionWeightByColumn: " << input_arg.Name() << " is not an initializer"; + if (!graph.GetInitializedTensor(original_name, tensor_proto)) { + LOGS_DEFAULT(WARNING) << "PartitionWeightByColumn: " << original_name << " is not an initializer"; return false; } auto data_type = tensor_proto->data_type(); @@ -130,11 +151,11 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node if (rank == 2 && utils::HasDimValue(shape->dim(0)) && utils::HasDimValue(shape->dim(1))) { row_count = shape->dim(0).dim_value(); column_count = shape->dim(1).dim_value(); - weight_partition_info_[input_arg.Name()].original_dim = std::vector{row_count, column_count}; + weight_partition_info_[original_name].original_dim = std::vector{row_count, column_count}; } else if (rank == 1) { row_count = 1; column_count = shape->dim(0).dim_value(); - weight_partition_info_[input_arg.Name()].original_dim = std::vector{column_count}; + weight_partition_info_[original_name].original_dim = std::vector{column_count}; } else { LOGS_DEFAULT(WARNING) << "Initializer tensor's rank is " << rank << " (expected to be 
1 or 2)."; return false; @@ -147,47 +168,118 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node return false; } + if (stride > 1){ + LOGS_DEFAULT(WARNING) << "Checkpointing is not currently supported for graphs requiring partitioning of weight with stride > 1"; + } + auto initializer = onnxruntime::make_unique(*tensor_proto, graph.ModelPath()); const float* a_weight = initializer->data(); - std::string new_initializer_name = input_arg.Name() + "_column_rank_" + std::to_string(horizontal_parallel_rank_); + std::string new_initializer_name = original_name + "_column_rank_" + std::to_string(horizontal_parallel_rank_); initializer_partition.set_name(new_initializer_name); initializer_partition.set_data_type(data_type); int64_t column_partition = column_count / horizontal_parallel_size_; int64_t column_stride = column_count / stride; - int64_t column_stride_partition = column_stride / horizontal_parallel_size_; + std::vector new_shape; if (rank == 2) { initializer_partition.add_dims(row_count); + new_shape.push_back(row_count); } initializer_partition.add_dims(column_partition); + new_shape.push_back(column_partition); const int64_t element_count = row_count * column_partition; std::vector result; result.reserve(element_count); - const int64_t stride_partition_column_offset = horizontal_parallel_rank_ * column_stride_partition; - for (auto row_index = 0; row_index < row_count; row_index++) { - auto row_offset = row_index * column_count; - for (auto stride_index = 0; stride_index < stride; stride_index++) { - auto column_offset = row_offset + stride_index * column_stride + stride_partition_column_offset; - std::copy(a_weight + column_offset, a_weight + column_offset + column_stride_partition, std::back_inserter(result)); + PartitionBufferByColumn(a_weight, row_count, column_count, column_stride, stride, result); + initializer_partition.set_raw_data(result.data(), element_count * sizeof(float)); + + // Partition initial optimizer state if available + const auto optim_state_it = initial_optimizer_states_.find(original_name); + if (optim_state_it != initial_optimizer_states_.end()) { + auto& initial_states = optim_state_it->second; + // partition moments same way as the weight + for (const auto& moments_prefix : training::MOMENTS_PREFIXES) { + const auto initial_state_it = initial_states.find(moments_prefix); + if (initial_state_it != initial_states.end()) { + auto* init_tensor = initial_state_it->second.GetMutable(); + + OrtValue partitioned; + auto element_type = init_tensor->DataType(); + TensorShape partition_shape(new_shape); + std::unique_ptr p_tensor; + + if (utils::IsPrimitiveDataType(element_type)) { + float* data_buffer = init_tensor->MutableData(); + + // allocate temporary memory to get the column partitioned state + std::vector result_buffer; + result_buffer.reserve(element_count); + PartitionBufferByColumn(data_buffer, row_count, column_count, column_stride, stride, result_buffer); + + // We need to maintain the initial optimizer states as an OrtValue, + // which is converted eventually to a TensorProto in the optimizer builder + // after Megatron and Zero partitioning. This approach saves CPU memory + // as creating a TensorProto involves a copy, and by delaying the copy until + // after the partitioning results in a smaller copy only for the optimizer + // states currently present on the rank. 
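As a cross-check of the partitioning loop introduced above, here is a minimal NumPy sketch of the same column-partitioning scheme (the function and variable names are illustrative, not part of this change): each of the stride blocks of width column_count / stride contributes the contiguous slice owned by the current horizontal-parallel rank, and the per-row, per-block copies land in the same order as the C++ std::copy calls.

import numpy as np

def partition_by_column(flat, row_count, column_count, stride, rank, world_size):
    # flat: 1-D row-major buffer of length row_count * column_count
    x = np.asarray(flat).reshape(row_count, column_count)
    column_stride = column_count // stride        # width of one stride block
    part = column_stride // world_size            # columns this rank keeps per block
    offset = rank * part                          # stride_partition_column_offset
    pieces = [x[:, s * column_stride + offset: s * column_stride + offset + part]
              for s in range(stride)]
    # concatenate per-block slices along columns, then flatten row-major,
    # matching the order in which the C++ loop appends into `result`
    return np.concatenate(pieces, axis=1).reshape(-1)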
+ // Allocate a new buffer to hold the partitioned optimizer state + // as column partitioning cannot re-use the original + // buffer as it is a non-contiguous read + auto alloc = cpu_execution_provider_.GetAllocator(0, OrtMemTypeDefault); + p_tensor = onnxruntime::make_unique(element_type, + partition_shape, + alloc); + float* out_buffer = p_tensor->MutableData(); + memcpy(out_buffer, result_buffer.data(), sizeof(float) * element_count); + } else if (utils::IsPrimitiveDataType(element_type)) { + MLFloat16* data_buffer = init_tensor->MutableData(); + + // allocate temporary memory to get the column partitioned state + std::vector result_buffer; + result_buffer.reserve(element_count); + PartitionBufferByColumn(data_buffer, row_count, column_count, column_stride, stride, result_buffer); + + // allocate a new buffer as column partitioning cannot re-use the original + // buffer as it is a non-contiguous read on original buffer + auto alloc = cpu_execution_provider_.GetAllocator(0, OrtMemTypeDefault); + p_tensor = onnxruntime::make_unique(element_type, + partition_shape, + alloc); + MLFloat16* out_buffer = p_tensor->MutableData(); + memcpy(out_buffer, result_buffer.data(), sizeof(MLFloat16) * element_count); + } else { + ORT_THROW("Unsupported type: ", element_type, "for initial optimizer moments."); + } + partitioned.Init(p_tensor.release(), + DataTypeImpl::GetType(), + DataTypeImpl::GetType()->GetDeleteFunc()); + initial_states[moments_prefix] = std::move(partitioned); + } else { + LOGS_DEFAULT(WARNING) << "Initial value for optimizer state: " << moments_prefix + << " not found for weight: " << original_name; + } } } - initializer_partition.set_raw_data(result.data(), element_count * sizeof(float)); - weight_partition_info_[new_initializer_name].megatron_row_partition = 0; + weight_partition_info_[original_name].megatron_row_partition = 0; + weight_partition_info_[original_name].partition_name = new_initializer_name; + weight_partition_info_[original_name].weight_partitioned = true; + return true; } bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg& input_arg, ONNX_NAMESPACE::TensorProto& initializer_partition) const { + const std::string original_name = input_arg.Name(); const ONNX_NAMESPACE::TensorProto* tensor_proto; - if (!graph.GetInitializedTensor(input_arg.Name(), tensor_proto)) { - LOGS_DEFAULT(WARNING) << "PartitionWeightByRow: " << input_arg.Name() << " is not an initializer"; + if (!graph.GetInitializedTensor(original_name, tensor_proto)) { + LOGS_DEFAULT(WARNING) << "PartitionWeightByRow: " << original_name << " is not an initializer"; + return false; } @@ -200,11 +292,11 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg if (rank == 2 && utils::HasDimValue(shape->dim(0)) && utils::HasDimValue(shape->dim(1))) { row_count = shape->dim(0).dim_value(); column_count = shape->dim(1).dim_value(); - weight_partition_info_[input_arg.Name()].original_dim = std::vector{row_count, column_count}; + weight_partition_info_[original_name].original_dim = std::vector{row_count, column_count}; } else if (rank == 1) { row_count = shape->dim(0).dim_value(); column_count = 1; - weight_partition_info_[input_arg.Name()].original_dim = std::vector{row_count}; + weight_partition_info_[original_name].original_dim = std::vector{row_count}; } else { LOGS_DEFAULT(WARNING) << "Initializer tensor's rank is more than " << rank << " (expected to be 1 or 2)."; @@ -219,16 +311,19 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph,
const NodeArg auto initializer = onnxruntime::make_unique(*tensor_proto, graph.ModelPath()); const float* a_weight = initializer->data(); - std::string new_initializer_name = input_arg.Name() + "_row_rank_" + std::to_string(horizontal_parallel_rank_); + std::string new_initializer_name = original_name + "_row_rank_" + std::to_string(horizontal_parallel_rank_); initializer_partition.set_name(new_initializer_name); initializer_partition.set_data_type(data_type); int64_t row_partition = row_count / horizontal_parallel_size_; + std::vector new_shape; initializer_partition.add_dims(row_partition); + new_shape.push_back(row_partition); if (rank == 2) { initializer_partition.add_dims(column_count); + new_shape.push_back(column_count); } const int64_t element_count = row_partition * column_count; @@ -238,7 +333,54 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg const int64_t row_index_offset = horizontal_parallel_rank_ * row_partition; memcpy(result.data(), a_weight + row_index_offset * column_count, sizeof(float) * element_count); initializer_partition.set_raw_data(result.data(), element_count * sizeof(float)); - weight_partition_info_[new_initializer_name].megatron_row_partition = 1; + + // Partition initial optimizer state if available + const auto optim_state_it = initial_optimizer_states_.find(original_name); + if (optim_state_it != initial_optimizer_states_.end()) { + auto& initial_states = optim_state_it->second; + for (const auto& moments_prefix : training::MOMENTS_PREFIXES) { + const auto initial_state_it = initial_states.find(moments_prefix); + if (initial_state_it != initial_states.end()) { + auto* init_tensor = initial_state_it->second.GetMutable(); + + OrtValue partitioned; + auto element_type = init_tensor->DataType(); + TensorShape partition_shape(new_shape); + const OrtMemoryInfo& info = init_tensor->Location(); + std::unique_ptr p_tensor; + + if (utils::IsPrimitiveDataType(element_type)) { + float* data_buffer = init_tensor->MutableData(); + + p_tensor = onnxruntime::make_unique(element_type, + partition_shape, + data_buffer + row_index_offset * column_count, + info); + } else if (utils::IsPrimitiveDataType(element_type)) { + MLFloat16* data_buffer = init_tensor->MutableData(); + + p_tensor = onnxruntime::make_unique(element_type, + partition_shape, + data_buffer + row_index_offset * column_count, + info); + + } else { + ORT_THROW("Unsupported type: ", element_type, "for initial optimizer moments."); + } + partitioned.Init(p_tensor.release(), + DataTypeImpl::GetType(), + DataTypeImpl::GetType()->GetDeleteFunc()); + initial_states[moments_prefix] = std::move(partitioned); + } else { + LOGS_DEFAULT(WARNING) << "Initial value for optimizer state: " << moments_prefix + << " not found for weight: " << original_name; + } + } + } + + weight_partition_info_[original_name].megatron_row_partition = 1; + weight_partition_info_[original_name].partition_name = new_initializer_name; + weight_partition_info_[original_name].weight_partitioned = true; return true; } @@ -409,14 +551,19 @@ Status MegatronTransformer::TransformBARTMLP(Graph& graph, bool& modified, biasgelu_node.GetOutputEdgesCount() != 1) { return skip_status; } - Node& dropout_node = *graph.GetNode(biasgelu_node.OutputNodesBegin()->Index()); - if (!IsExpectedOpAndProvider(dropout_node, dropout_info, provider_type)) { - return skip_status; + + // Either Dropout->Matmul or just Matmul + Node* dropout_node = nullptr; + Node* next_node = graph.GetNode(biasgelu_node.OutputNodesBegin()->Index()); + if 
(IsExpectedOpAndProvider(*next_node, dropout_info, provider_type)) { + dropout_node = next_node; + next_node = graph.GetNode(dropout_node->OutputNodesBegin()->Index()); } - Node& matmul2_node = *graph.GetNode(dropout_node.OutputNodesBegin()->Index()); - if (!IsExpectedOpAndProvider(matmul2_node, matmul_info, provider_type)) { + if (!IsExpectedOpAndProvider(*next_node, matmul_info, provider_type)) { return skip_status; } + Node& matmul2_node = *next_node; + Node& add_node = *graph.GetNode(matmul2_node.OutputNodesBegin()->Index()); if (!IsExpectedOpAndProvider(add_node, add_info, provider_type)) { return skip_status; @@ -430,8 +577,11 @@ Status MegatronTransformer::TransformBARTMLP(Graph& graph, bool& modified, return skip_status; } - nodes_to_clear_shape.insert(nodes_to_clear_shape.end(), {&node, second_op, &biasgelu_node, &dropout_node, + nodes_to_clear_shape.insert(nodes_to_clear_shape.end(), {&node, second_op, &biasgelu_node, &matmul2_node, transpose_op_ptr}); + if (dropout_node != nullptr) { + nodes_to_clear_shape.insert(nodes_to_clear_shape.end(), {dropout_node}); + } auto dense_wi_weight_arg = second_op->MutableInputDefs()[0]; ONNX_NAMESPACE::TensorProto dense_wi_weight_initializer_partition; @@ -439,7 +589,7 @@ Status MegatronTransformer::TransformBARTMLP(Graph& graph, bool& modified, return skip_status; } - //since the bias doesnt get transposed, partitioning by col + //since the bias doesn't get transposed, partitioning by col auto dense_wi_bias_arg = biasgelu_node.MutableInputDefs()[1]; ONNX_NAMESPACE::TensorProto dense_wi_bias_initializer_partition; if (!PartitionWeightByColumn(graph, *dense_wi_bias_arg, dense_wi_bias_initializer_partition)) { @@ -468,7 +618,9 @@ Status MegatronTransformer::TransformBARTMLP(Graph& graph, bool& modified, graph.RemoveInitializedTensor(dense_wi_bias_arg->Name()); graph.RemoveInitializedTensor(dense_wo_weight_arg->Name()); - dropout_nodes_to_transform.insert(&dropout_node); + if (dropout_node) { + dropout_nodes_to_transform.insert(dropout_node); + } const std::vector mlp_f_input_defs{node.MutableInputDefs()[0]}; auto mlp_f_type_info = *node.MutableInputDefs()[0]->TypeAsProto(); @@ -1217,15 +1369,17 @@ Status MegatronTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_le .IsOK()); auto& graph_inputs = graph.GetInputs(); - for (auto& node : nodes_to_clear_shape) { - auto& inputs = node->MutableInputDefs(); - for (auto* input : inputs) - if (std::find(graph_inputs.begin(), graph_inputs.end(), input) == graph_inputs.end()) - input->ClearShape(); - - for (auto* output : node->MutableOutputDefs()) - if (std::find(graph_inputs.begin(), graph_inputs.end(), output) == graph_inputs.end()) - output->ClearShape(); + for (auto node : nodes_to_clear_shape) { + if (node != nullptr) { + auto& inputs = node->MutableInputDefs(); + for (auto* input : inputs) + if (std::find(graph_inputs.begin(), graph_inputs.end(), input) == graph_inputs.end()) + input->ClearShape(); + + for (auto* output : node->MutableOutputDefs()) + if (std::find(graph_inputs.begin(), graph_inputs.end(), output) == graph_inputs.end()) + output->ClearShape(); + } } for (auto x : updated_weight_names_) { diff --git a/orttraining/orttraining/core/optimizer/megatron_transformer.h b/orttraining/orttraining/core/optimizer/megatron_transformer.h index 410c24bcff935..d390d319594bc 100644 --- a/orttraining/orttraining/core/optimizer/megatron_transformer.h +++ b/orttraining/orttraining/core/optimizer/megatron_transformer.h @@ -15,13 +15,17 @@ class MegatronTransformer : public GraphTransformer { 
std::unordered_map& updated_weight_names, std::unordered_set& weights_to_train, std::unordered_map& weight_partition_info, + training::TrainingSession::OptimizerState& initial_optimizer_states, + const IExecutionProvider& cpu_execution_provider, // Required to get allocator for optimizer partitioning by Col const std::unordered_set& compatible_execution_providers = {}) noexcept : GraphTransformer("MegatronTransformer", compatible_execution_providers), horizontal_parallel_rank_(horizontal_parallel_rank), horizontal_parallel_size_(horizontal_parallel_size), updated_weight_names_(updated_weight_names), weights_to_train_(weights_to_train), - weight_partition_info_(weight_partition_info) {} + weight_partition_info_(weight_partition_info), + initial_optimizer_states_(initial_optimizer_states), + cpu_execution_provider_(cpu_execution_provider) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; @@ -63,6 +67,14 @@ class MegatronTransformer : public GraphTransformer { std::unordered_set& dropout_nodes_to_transform, int32_t& counter) const; + template + void PartitionBufferByColumn(const T* input, + const int64_t row_count, + const int64_t column_count, + const int64_t column_stride, + const int stride, + std::vector& result) const; + bool PartitionWeightByColumn(const Graph& graph, const NodeArg& input_arg, ONNX_NAMESPACE::TensorProto& initializer_partition, int stride = 1) const; @@ -75,6 +87,8 @@ class MegatronTransformer : public GraphTransformer { std::unordered_map& updated_weight_names_; std::unordered_set& weights_to_train_; std::unordered_map& weight_partition_info_; + training::TrainingSession::OptimizerState& initial_optimizer_states_; + const IExecutionProvider& cpu_execution_provider_; }; } // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 2d374745279ba..c4395c2145aec 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -53,6 +53,7 @@ Status SetupOptimizerParams( const std::unordered_map& fp32_weight_names_to_mixed_precision_node_args, const optional& loss_scale_input_name, const TrainingSession::TrainingConfiguration& config, + const TrainingSession::OptimizerState& init_optimizer_states, OptimizerGraphConfig& opt_graph_config_result, std::unordered_map& opt_node_configs_result, std::unordered_map& weight_name_map_after_graph_transform) { @@ -98,12 +99,10 @@ Status SetupOptimizerParams( opt_node_config.mixed_precision_weight_arg = mixed_precision_weight_name_it->second; } - // check if initial optimizer states have been provided for weight - if (config.init_optimizer_states) { - const auto optim_state_it = config.init_optimizer_states->find(weight_name); - if (optim_state_it != config.init_optimizer_states->end()) { - opt_node_config.initial_states = optim_state_it->second; - } + // retrieve value for initial optimizer states if provided for weight + const auto optim_state_it = init_optimizer_states.find(original_weight_name); + if (optim_state_it != init_optimizer_states.end()) { + opt_node_config.initial_states = optim_state_it->second; } opt_node_configs.emplace(weight_name, std::move(opt_node_config)); @@ -130,12 +129,11 @@ Status SetupOptimizerParams( opt_graph_config.deepspeed_zero = optimizer_config.deepspeed_zero; // check if shared initial optimizer states have been provided - if
(config.init_optimizer_states) { - const auto optim_state_it = config.init_optimizer_states->find(onnxruntime::training::SHARED_OPTIMIZER_STATES_KEY); - if (optim_state_it != config.init_optimizer_states->end()) { - opt_graph_config.shared_optimizer_states = std::move(optim_state_it->second); - } + const auto optim_state_it = init_optimizer_states.find(onnxruntime::training::SHARED_OPTIMIZER_STATES_KEY); + if (optim_state_it != init_optimizer_states.end()) { + opt_graph_config.shared_optimizer_states = std::move(optim_state_it->second); } + opt_node_configs_result = std::move(opt_node_configs); opt_graph_config_result = std::move(opt_graph_config); @@ -420,6 +418,10 @@ Status TrainingSession::ConfigureForTraining( } } + if (config.init_optimizer_states) { + init_optimizer_states_ = config.init_optimizer_states.value(); + } + ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config)); ORT_RETURN_IF_ERROR(ApplyModelParallelTransformationsToMainGraph(trainable_initializers, config_result)); @@ -501,7 +503,7 @@ Status TrainingSession::ConfigureForTraining( std::unordered_map opt_node_configs{}; ORT_RETURN_IF_ERROR(SetupOptimizerParams( weights_to_train_, fp32_weight_name_to_mixed_precision_node_arg, - loss_scale_input_name, config, opt_graph_config, opt_node_configs, config_result.weight_name_map_after_graph_transform)); + loss_scale_input_name, config, init_optimizer_states_, opt_graph_config, opt_node_configs, config_result.weight_name_map_after_graph_transform)); TrainingConfigurationResult::OptimizerConfigurationResult optimizer_config_result{}; ORT_RETURN_IF_ERROR(BuildOptimizer( opt_graph_config, opt_node_configs, @@ -796,12 +798,16 @@ Status TrainingSession::ApplyModelParallelTransformationsToMainGraph(std::unorde GraphTransformerManager graph_transformation_mgr{1}; std::vector> transformers_to_register; + // Creating the CPU EP here to be used to get the + // CPU allocator for partitioning the optimizer state by column. 
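For reference, the initial optimizer state threaded through SetupOptimizerParams and into the Megatron transformer above is keyed by the original weight name, with one tensor per Adam moment. A rough Python-side analogue of that structure (weight names and shapes are hypothetical; the moment keys follow the conventional MOMENTS_PREFIXES names, and the shared-state key is shown only as a placeholder):

import torch

# illustrative shape of init_optimizer_states / config.init_optimizer_states
init_optimizer_states = {
    "encoder.layer0.weight": {
        "Moment_1": torch.zeros(1024, 1024),   # first-order moment
        "Moment_2": torch.zeros(1024, 1024),   # second-order moment
    },
    # optional state shared across weights (e.g. a step counter); the real key
    # is SHARED_OPTIMIZER_STATES_KEY, represented here by a placeholder string
    "<shared_optimizer_states_key>": {"Step": torch.tensor(1)},
}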
+ std::unique_ptr cpu_execution_provider = + onnxruntime::make_unique(CPUExecutionProviderInfo()); std::unordered_set compatible_eps = {}; LOGS_DEFAULT(WARNING) << horizontal_parallel_size << "-way horizontal model parallel is enabled"; transformers_to_register.emplace_back(onnxruntime::make_unique( training::DistributedRunContext::RankInGroup(training::WorkerGroupType::HorizontalParallel), horizontal_parallel_size, config_result_out.weight_name_map_after_graph_transform, weights_to_train, - config_result_out.weight_partition_info, compatible_eps)); + config_result_out.weight_partition_info, init_optimizer_states_, *cpu_execution_provider, compatible_eps)); // Generate and register transformers for level for (auto& entry : transformers_to_register) { @@ -1082,10 +1088,13 @@ common::Status TrainingSession::GetOptimizerState(std::unordered_mapsecond; - opt_state_tensors.erase(it); + const auto& it = opt_state_tensors.find(weight.second.partition_name); + if (it == opt_state_tensors.end()) { + ORT_RETURN_IF_NOT(allow_missing, "Failed to get optimizer params for partition: " + weight.second.partition_name); + } else { + opt_state_tensors[weight.first] = it->second; + opt_state_tensors.erase(it); + } } return Status::OK(); } @@ -1095,20 +1104,29 @@ common::Status TrainingSession::GetModelState(std::unordered_map fp_tensor_names{}; fp_tensor_names.insert( weights_to_train_.begin(), weights_to_train_.end()); - // Add zero sharded weights, only needed for fp32 weights in mixed precision run - for (const auto& weight_sharded_pair : updated_weight_names_map_) { - fp_tensor_names.erase(weight_sharded_pair.first); // remove the original name - fp_tensor_names.insert(weight_sharded_pair.second); + // Add sharded weights + for (const auto& weight : weight_partition_info_) { + if (weight.second.weight_partitioned) { + fp_tensor_names.erase(weight.first); // remove the original name + fp_tensor_names.insert(weight.second.partition_name); + } } + NameMLValMap fp_weights; GetSessionState().GetInitializedTensors(fp_tensor_names, allow_missing, fp_weights); - // Change key from sharded_name to weight_name - for (const auto& weight_sharded_pair : updated_weight_names_map_) { - const auto& it = fp_weights.find(weight_sharded_pair.second); - ORT_ENFORCE(it != fp_weights.end(), "Cannot find weight: " + weight_sharded_pair.second + " in updated_weight_names_map_"); - fp_weights[weight_sharded_pair.first] = it->second; - fp_weights.erase(it); + // Change key from sharded_name to weight_name using partition_info + for (const auto& weight : weight_partition_info_) { + if (weight.second.weight_partitioned) { + const auto& it = fp_weights.find(weight.second.partition_name); + if (it == fp_weights.end()) { + ORT_RETURN_IF_NOT(allow_missing, "Failed to get weight partition: " + weight.second.partition_name); + } else { + fp_weights[weight.first] = it->second; + fp_weights.erase(it); + } + } } + model_state_tensors["full_precision"] = fp_weights; if (include_mixed_precision_weights) { std::unordered_set mp_tensor_names{}; diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 736127c199486..7b94818fe2e01 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -32,9 +32,15 @@ class TrainingSession : public InferenceSession { * Partition information of each paritioned weight */ struct PartitionInfo { + // value of the original shape of the weight std::vector 
original_dim; + // indicates whether the weight was megatron partitioned or not. + // -1: not partitioned; 0: column partitioned; 1: row partitioned int megatron_row_partition = -1; - std::string view_name; + // name of the partition used to look up partitioned weight and optimizer state values + std::string partition_name; + // whether the weight itself was partitioned or not (e.g., just the optimizer state for fp32 ZeRO-1) + bool weight_partitioned = false; }; TrainingSession(const SessionOptions& session_options, const Environment& env) @@ -551,6 +557,7 @@ class TrainingSession : public InferenceSession { bool is_configured_{false}; std::unordered_set weights_to_train_; + OptimizerState init_optimizer_states_; // names of additional initializers to be included in checkpoints std::unordered_map updated_weight_names_map_; std::unordered_set opt_state_initializer_names_; diff --git a/orttraining/orttraining/python/training/_utils.py b/orttraining/orttraining/python/training/_utils.py index 7d110f08f737a..035c4211858d5 100644 --- a/orttraining/orttraining/python/training/_utils.py +++ b/orttraining/orttraining/python/training/_utils.py @@ -242,6 +242,16 @@ def state_dict_trainer_options_world_size_key(): return 'world_size' +def state_dict_trainer_options_data_parallel_size_key(): + """Returns the trainer options data_parallel_size key name in the state dictionary""" + + return 'data_parallel_size' + +def state_dict_trainer_options_horizontal_parallel_size_key(): + """Returns the trainer options horizontal_parallel_size key name in the state dictionary""" + + return 'horizontal_parallel_size' + def state_dict_trainer_options_optimizer_name_key(): """Returns the trainer options optimizer_name key name in the state dictionary""" diff --git a/orttraining/orttraining/python/training/checkpoint.py b/orttraining/orttraining/python/training/checkpoint.py index c343d03ee64db..364eac64b0a8f 100644 --- a/orttraining/orttraining/python/training/checkpoint.py +++ b/orttraining/orttraining/python/training/checkpoint.py @@ -3,6 +3,8 @@ import os import torch import warnings +import tempfile +from enum import Enum from .
import _checkpoint_storage, _utils @@ -109,8 +111,14 @@ def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix= else: return _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) -def _order_paths(paths): - """Reorders the given paths in ascending order of rank and return the ordered list""" + +class _AGGREGATION_MODE(Enum): + Zero = 0 + Megatron = 1 + +def _order_paths(paths, D_groups, H_groups): + """Reorders the given paths in order of aggregation of ranks for D and H parallellism respectively + and returns the ordered dict""" trainer_options_path_tuples = [] world_rank = _utils.state_dict_trainer_options_world_rank_key() @@ -119,29 +127,39 @@ def _order_paths(paths): trainer_options_path_tuples.append((_checkpoint_storage.load(path, key=_utils.state_dict_trainer_options_key()), path)) - ordered_paths = [path for _, path in sorted(trainer_options_path_tuples, + # sort paths according to rank + sorted_paths = [path for _, path in sorted(trainer_options_path_tuples, key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank])] + + ordered_paths = dict() + ordered_paths['D'] = [[sorted_paths[i] for i in D_groups[group_id]] for group_id in range(len(D_groups))] + ordered_paths['H'] = [[sorted_paths[i] for i in H_groups[group_id]] for group_id in range(len(H_groups))] return ordered_paths -def _add_or_update_sharded_key_for_zero(state_key, state_value, state_sub_dict, - model_state_key, original_dim, sharded_states_original_dims): +def _add_or_update_sharded_key(state_key, state_value, state_sub_dict, + model_state_key, state_partition_info, sharded_states_original_dims, mode): """Add or update the record for the sharded state_key in the state_sub_dict""" # record the original dimension for this state - sharded_states_original_dims[model_state_key] = original_dim + original_dim = _utils.state_dict_original_dimension_key() + sharded_states_original_dims[model_state_key] = state_partition_info[original_dim] + + axis = 0 + if mode == _AGGREGATION_MODE.Megatron and state_partition_info["megatron_row_partition"] == 0: + axis = -1 if state_key in state_sub_dict: # state_dict already contains a record for this state # since this state is sharded, concatenate the state value to # the record in the state_dict state_sub_dict[state_key] = \ - np.concatenate((state_sub_dict[state_key], state_value)) + np.concatenate((state_sub_dict[state_key], state_value), axis) else: # create a new entry for this state in the state_dict state_sub_dict[state_key] = state_value -def _add_or_validate_unsharded_key_for_zero(state_key, state_value, state_sub_dict, mismatch_error_string): +def _add_or_validate_unsharded_key(state_key, state_value, state_sub_dict, mismatch_error_string): """Add or validate the record for the unsharded state_key in the state_sub_dict""" if state_key in state_sub_dict: @@ -152,13 +170,12 @@ def _add_or_validate_unsharded_key_for_zero(state_key, state_value, state_sub_di # create a new entry for this state in the state_sub_dict state_sub_dict[state_key] = state_value -def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled): +def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled, mode = _AGGREGATION_MODE.Zero): """Aggregates all model states from the rank_state_dict into state_dict""" model = _utils.state_dict_model_key() full_precision = _utils.state_dict_full_precision_key() partition_info = 
_utils.state_dict_partition_info_key() - original_dim = _utils.state_dict_original_dimension_key() # if there are no model states in the rank_state_dict, no model aggregation is needed if model not in rank_state_dict: @@ -172,24 +189,24 @@ def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state # iterate over all model state keys for model_state_key, model_state_value in rank_state_dict[model][full_precision].items(): - # full precision model states are sharded only when they exist in the partition_info subdict and mixed + # ZERO: full precision model states are sharded only when they exist in the partition_info subdict and mixed # precision training was enabled. for full precision training, full precision model states are not sharded - if mixed_precision_enabled and (model_state_key in rank_state_dict[partition_info]): - # this model state is sharded since a record exists in the partition_info subdict - _add_or_update_sharded_key_for_zero(model_state_key, model_state_value, + # MEGATRON : full precision model states are sharded when they exist in the partition_info subdict + if (model_state_key in rank_state_dict[partition_info]) and (mode == _AGGREGATION_MODE.Megatron or mixed_precision_enabled): + # this model state is sharded + _add_or_update_sharded_key(model_state_key, model_state_value, state_dict[model][full_precision], model_state_key, - rank_state_dict[partition_info][model_state_key][original_dim], sharded_states_original_dims) + rank_state_dict[partition_info][model_state_key], sharded_states_original_dims, mode) else: # this model state is not sharded since a record for it does not exist in the partition_info subdict - _add_or_validate_unsharded_key_for_zero(model_state_key, model_state_value, + _add_or_validate_unsharded_key(model_state_key, model_state_value, state_dict[model][full_precision], "Value mismatch for model state {}".format(model_state_key)) -def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict): +def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode = _AGGREGATION_MODE.Zero): """Aggregates all optimizer states from the rank_state_dict into state_dict""" optimizer = _utils.state_dict_optimizer_key() partition_info = _utils.state_dict_partition_info_key() - original_dim = _utils.state_dict_original_dimension_key() sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() # if there are no optimizer states in the rank_state_dict, no optimizer aggregation is needed @@ -207,13 +224,13 @@ def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, s if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]: # this optimizer state is sharded since a record exists in the partition_info subdict - _add_or_update_sharded_key_for_zero(optimizer_key, optimizer_value, + _add_or_update_sharded_key(optimizer_key, optimizer_value, state_dict[optimizer][model_state_key], model_state_key, - rank_state_dict[partition_info][model_state_key][original_dim], sharded_states_original_dims) + rank_state_dict[partition_info][model_state_key], sharded_states_original_dims, mode) else: # this optimizer state is not sharded since a record for it does not exist in the partition_info subdict # or this optimizer key is not one of the sharded optimizer keys - _add_or_validate_unsharded_key_for_zero(optimizer_key, optimizer_value, + _add_or_validate_unsharded_key(optimizer_key, optimizer_value, 
state_dict[optimizer][model_state_key], "Value mismatch for model state {} and optimizer state {}".format(model_state_key, optimizer_key)) @@ -237,24 +254,39 @@ def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_en if optimizer_key in sharded_optimizer_keys: state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim) -def _aggregate_trainer_options(rank_state_dict, state_dict): +def _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation): """Extracts trainer options from rank_state_dict and loads them accordingly on state_dict""" - - state_dict[_utils.state_dict_trainer_options_key()] = {} + trainer_options = _utils.state_dict_trainer_options_key() + state_dict[trainer_options] = {} mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() zero_stage = _utils.state_dict_trainer_options_zero_stage_key() world_rank = _utils.state_dict_trainer_options_world_rank_key() world_size = _utils.state_dict_trainer_options_world_size_key() optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - - state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] = \ - rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] - state_dict[_utils.state_dict_trainer_options_key()][zero_stage] = 0 - state_dict[_utils.state_dict_trainer_options_key()][world_rank] = 0 - state_dict[_utils.state_dict_trainer_options_key()][world_size] = 1 - state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] = \ - rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] + D_size = _utils.state_dict_trainer_options_data_parallel_size_key() + H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() + + state_dict[trainer_options][mixed_precision] = rank_state_dict[trainer_options][mixed_precision] + state_dict[trainer_options][zero_stage] = 0 + state_dict[trainer_options][world_rank] = rank_state_dict[trainer_options][world_rank] if partial_aggregation else 0 + state_dict[trainer_options][world_size] = 1 + state_dict[trainer_options][optimizer_name] = rank_state_dict[trainer_options][optimizer_name] + state_dict[trainer_options][D_size] = 1 + state_dict[trainer_options][H_size] = 1 + +def _aggregate_megatron_partition_info(rank_state_dict, state_dict): + """Extracts partition_info from rank_state_dict and loads on state_dict for megatron-partitioned weights""" + partition_info = _utils.state_dict_partition_info_key() + if partition_info not in state_dict: + state_dict[partition_info] = {} + + rank_partition_info = rank_state_dict[partition_info] + for model_state_key, partition_info_dict in rank_partition_info.items(): + if model_state_key not in state_dict[partition_info]: + # add partition info only if weight is megatron partitioned + if (partition_info_dict["megatron_row_partition"] >= 0): + state_dict[partition_info][model_state_key] = partition_info_dict def _to_pytorch_format(state_dict): """Convert ORT state dictionary schema (hierarchical structure) to PyTorch state dictionary schema (flat structure)""" @@ -266,26 +298,44 @@ def _to_pytorch_format(state_dict): pytorch_state_dict[model_state_key] = torch.tensor(model_state_value) return pytorch_state_dict -def aggregate_checkpoints(paths, pytorch_format=True): - """Aggregate checkpoint files and return a single state dictionary - - Aggregates checkpoint files specified by paths and laods the checkpoint file one at a time merging - them into a single state dictionary. 
- The checkpoint files represented by paths must be saved through ORTTrainer.save_checkpoint() function. - The schema of the state_dict returned will be in the same as the one returned by ORTTrainer.state_dict() +def _get_parallellism_groups(data_parallel_size, horizontal_parallel_size, world_size): + """Returns the D and H groups for the given sizes""" + num_data_groups = world_size // data_parallel_size + data_groups = [] + for data_group_id in range(num_data_groups): + data_group_ranks=[] + for r in range(data_parallel_size): + data_group_ranks.append(data_group_id + horizontal_parallel_size * r) + data_groups.append(data_group_ranks) + + num_horizontal_groups = world_size // horizontal_parallel_size + horizontal_groups = [] + for hori_group_id in range(num_horizontal_groups): + hori_group_ranks=[] + for r in range(horizontal_parallel_size): + hori_group_ranks.append(hori_group_id * horizontal_parallel_size + r) + horizontal_groups.append(hori_group_ranks) + + return data_groups, horizontal_groups + +def _aggregate_over_ranks(ordered_paths, ranks, sharded_states_original_dims = None, mode = _AGGREGATION_MODE.Zero, partial_aggregation = False, pytorch_format=True): + """Aggregate checkpoint files over set of ranks and return a single state dictionary Args: - paths: list of more than one file represented as strings where the checkpoint is saved + ordered_paths: list of paths in the order in which they must be aggregated + ranks: list of ranks that are to be aggregated + sharded_states_original_dims: dict containing the original dims for sharded states that are persisted over + multiple calls to _aggregate_over_ranks() + mode: mode of aggregation: Zero or Megatron + partial_aggregation: boolean flag to indicate whether to produce a partially + aggregated state which can be further aggregated over pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict Returns: state_dict that can be loaded into an ORTTrainer or into a PyTorch model """ - - # order the paths in ascending order of ranks - ordered_paths = _order_paths(paths) - state_dict = {} - sharded_states_original_dims = {} + if sharded_states_original_dims is None: + sharded_states_original_dims = dict() world_rank = _utils.state_dict_trainer_options_world_rank_key() mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() zero_stage = _utils.state_dict_trainer_options_zero_stage_key() @@ -297,12 +347,12 @@ def aggregate_checkpoints(paths, pytorch_format=True): loaded_zero_stage = None loaded_optimizer_name = None - for rank, path in enumerate(ordered_paths): + for i, path in enumerate(ordered_paths): rank_state_dict = _checkpoint_storage.load(path) assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info" assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options" - assert rank == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank], \ + assert ranks[i] == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank], \ "Unexpected rank in file at path {}. Expected {}, got {}".\ format(path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank]) if loaded_mixed_precision is None: @@ -327,28 +377,107 @@ def aggregate_checkpoints(paths, pytorch_format=True): "Optimizer name mismatch among checkpoint files. 
File: {}".format(path) # aggregate all model states - _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision) + _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision, mode) if not pytorch_format: # aggregate all optimizer states if pytorch_format is False - _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict) + _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode) + + # for D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups + # to aggregate over Zero, and another pass to aggregate Megatron partitioned + # states. Preserve the relevant partition info only for weights that are megatron partitioned for + # a partial aggregation call + if partial_aggregation: + _aggregate_megatron_partition_info(rank_state_dict, state_dict) # entry for trainer_options in the state_dict to perform other sanity checks if _utils.state_dict_trainer_options_key() not in state_dict: - _aggregate_trainer_options(rank_state_dict, state_dict) + _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation) # entry for user_dict in the state_dict if not already present if _utils.state_dict_user_dict_key() not in state_dict and \ _utils.state_dict_user_dict_key() in rank_state_dict: state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()] - # reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims - _reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision) + # for a partial aggregation scenario, we might not have the entire tensor aggregated yet, thus skip reshape + if not partial_aggregation: + # reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims + _reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision) # return a flat structure for PyTorch model in case pytorch_format is True # else return the hierarchical structure for ORTTrainer return _to_pytorch_format(state_dict) if pytorch_format else state_dict +def _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format): + """Aggregate checkpoint files and return a single state dictionary for the D+H + (Zero+Megatron) partitioning strategy. + For D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups + to aggregate over Zero, and another pass over the previously aggregated states + to aggregate Megatron partitioned states. + """ + sharded_states_original_dims = {} + aggregate_data_checkpoint_files = [] + + # combine for Zero over data groups and save to temp file + with tempfile.TemporaryDirectory() as save_dir: + for group_id, d_group in enumerate(D_groups): + aggregate_state_dict = _aggregate_over_ranks(ordered_paths['D'][group_id], d_group, sharded_states_original_dims, partial_aggregation = True, pytorch_format=False) + + filename = 'ort.data_group.' 
+ str(group_id) + '.ort.pt' + filepath = os.path.join(save_dir, filename) + _checkpoint_storage.save(aggregate_state_dict, filepath) + aggregate_data_checkpoint_files.append(filepath) + + assert len(aggregate_data_checkpoint_files) > 0 + + # combine for megatron: + aggregate_state = _aggregate_over_ranks(aggregate_data_checkpoint_files, H_groups[0], sharded_states_original_dims, mode = _AGGREGATION_MODE.Megatron, pytorch_format = pytorch_format) + + return aggregate_state + +def aggregate_checkpoints(paths, pytorch_format=True): + """Aggregate checkpoint files and return a single state dictionary + + Aggregates checkpoint files specified by paths and loads them one at a time, merging + them into a single state dictionary. + The checkpoint files represented by paths must be saved through ORTTrainer.save_checkpoint() function. + The schema of the state_dict returned will be in the same as the one returned by ORTTrainer.state_dict() + + Args: + paths: list of more than one file represented as strings where the checkpoint is saved + pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict + Returns: + state_dict that can be loaded into an ORTTrainer or into a PyTorch model + """ + + loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key()) + D_size = _utils.state_dict_trainer_options_data_parallel_size_key() + H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() + world_size = _utils.state_dict_trainer_options_world_size_key() + + D_size = loaded_trainer_options[D_size] + H_size = loaded_trainer_options[H_size] + world_size = loaded_trainer_options[world_size] + D_groups, H_groups = _get_parallellism_groups(D_size, H_size, world_size) + + combine_zero = loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 + combine_megatron = len(H_groups[0]) > 1 + + # order the paths in the order of groups in which they must be aggregated according to + # data-parallel groups and H-parallel groups obtained + # eg: {'D': [[path_0, path_2],[path_1, path_3]], 'H': [[path_0, path_1],[path_2, path_3]]} + ordered_paths = _order_paths(paths, D_groups, H_groups) + + aggregate_state = None + if combine_zero and combine_megatron: + aggregate_state = _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format) + elif combine_zero: + aggregate_state = _aggregate_over_ranks(ordered_paths['D'][0], D_groups[0], mode = _AGGREGATION_MODE.Zero, pytorch_format = pytorch_format) + elif combine_megatron: + aggregate_state = _aggregate_over_ranks(ordered_paths['H'][0], H_groups[0], mode = _AGGREGATION_MODE.Megatron, pytorch_format = pytorch_format) + + return aggregate_state + ################################################################################ # Helper functions ################################################################################ diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index 9a2cf22ec8a4d..6e321d965a932 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -610,6 +610,9 @@ def _create_ort_training_session(self, optimizer_state_dict={}): else: raise ValueError("Optimizer attributes must be either float or int.") + self.options.distributed.horizontal_parallel_size = max(self.options.distributed.horizontal_parallel_size, 1) + self.options.distributed.data_parallel_size = 
self.options.distributed.world_size // self.options.distributed.horizontal_parallel_size + # TrainingParameters ort_parameters = ort.TrainingParameters() ort_parameters.loss_output_name = loss_name @@ -922,6 +925,8 @@ def _extract_trainer_options(self, state_dict): world_rank = _utils.state_dict_trainer_options_world_rank_key() world_size = _utils.state_dict_trainer_options_world_size_key() optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() + D_size = _utils.state_dict_trainer_options_data_parallel_size_key() + H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() state_dict[_utils.state_dict_trainer_options_key()] = {} state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] = self.options.mixed_precision.enabled @@ -930,6 +935,8 @@ def _extract_trainer_options(self, state_dict): state_dict[_utils.state_dict_trainer_options_key()][world_rank] = self.options.distributed.world_rank state_dict[_utils.state_dict_trainer_options_key()][world_size] = self.options.distributed.world_size state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] = self.optim_config.name + state_dict[_utils.state_dict_trainer_options_key()][D_size] = self.options.distributed.data_parallel_size + state_dict[_utils.state_dict_trainer_options_key()][H_size] = self.options.distributed.horizontal_parallel_size def state_dict(self, pytorch_format=False): """Returns a dictionary with model, and optionally, optimizer states @@ -1024,6 +1031,14 @@ def state_dict(self, pytorch_format=False): "optimizer_name": { type: str + }, + "data_parallel_size": + { + type: int + }, + "horizontal_parallel_size": + { + type: int } } }, @@ -1041,6 +1056,10 @@ def state_dict(self, pytorch_format=False): "original_dim": { type: array + }, + "megatron_row_partition": + { + type: int } } } @@ -1075,6 +1094,8 @@ def state_dict(self, pytorch_format=False): if pytorch_format: if self.options.distributed.deepspeed_zero_optimization.stage > 0: warnings.warn("Incomplete state_dict: ZeRO enabled", UserWarning) + if self.options.distributed.horizontal_parallel_size > 1: + warnings.warn("Incomplete state_dict: Megatron enabled", UserWarning) # if pytorch_format is true, return a flat dictionary with only model states # which is compatible with a PyTorch model return state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] @@ -1086,7 +1107,7 @@ def state_dict(self, pytorch_format=False): self._extract_trainer_options(state_dict) # add partition information in case of a distributed run - if self.options.distributed.deepspeed_zero_optimization.stage > 0: + if self.options.distributed.deepspeed_zero_optimization.stage > 0 or self.options.distributed.horizontal_parallel_size > 1: state_dict[_utils.state_dict_partition_info_key()] = self._training_session.get_partition_info_map() return state_dict @@ -1162,7 +1183,7 @@ def _load_state_dict_impl(self, state_dict, strict=True): # clear the callable partial self._load_state_dict = None - def _mismatch_keys(keys1, keys2, in_error_str): + def _mismatch_keys(keys1, keys2, in_error_str, allow_unexpected=False): """Find out the missing and the unexpected keys in two dictionaries Throws a runtime error if missing or unexpected keys are found @@ -1175,48 +1196,48 @@ def _mismatch_keys(keys1, keys2, in_error_str): unexpected_keys = list(keys2 - keys1) if len(missing_keys) > 0: raise RuntimeError("Missing keys: {} in {}".format(missing_keys, in_error_str)) - if len(unexpected_keys) > 0: + if len(unexpected_keys) > 0 and not 
allow_unexpected: raise RuntimeError("Unexpected keys: {} in {}".format(unexpected_keys, in_error_str)) - def _check_model_key_mismatch(current_state_dict, state_dict): + def _check_model_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): """Check if there is any mismatch in the model sub state dictionary between the two state_dicts""" # check unxexpected and missing precision keys in the model state_dict compared to the training # session model state_dict _mismatch_keys(current_state_dict[_utils.state_dict_model_key()], - state_dict[_utils.state_dict_model_key()], 'state_dict[model]') + state_dict[_utils.state_dict_model_key()], 'state_dict[model]', allow_unexpected) # check for model state key mismatch for precision_key in current_state_dict[_utils.state_dict_model_key()]: _mismatch_keys(current_state_dict[_utils.state_dict_model_key()][precision_key], state_dict[_utils.state_dict_model_key()][precision_key], - 'state_dict[model][{}]'.format(precision_key)) + 'state_dict[model][{}]'.format(precision_key), allow_unexpected) - def _check_optimizer_key_mismatch(current_state_dict, state_dict): + def _check_optimizer_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): """Check if there is any mismatch in the optimizer sub state dictionary between the two state_dicts""" # check for model state key mismatch for the optimizer state_dict _mismatch_keys(current_state_dict[_utils.state_dict_optimizer_key()], state_dict[_utils.state_dict_optimizer_key()], - 'state_dict[optimizer]') + 'state_dict[optimizer]', allow_unexpected) # check for optimizer state keys mismatch for model_state_key in current_state_dict[_utils.state_dict_optimizer_key()]: _mismatch_keys(current_state_dict[_utils.state_dict_optimizer_key()][model_state_key], state_dict[_utils.state_dict_optimizer_key()][model_state_key], - 'state_dict[optimizer][{}]'.format(model_state_key)) + 'state_dict[optimizer][{}]'.format(model_state_key), allow_unexpected) - def _check_key_mismatch(current_state_dict, state_dict): + def _check_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): """Check if there is a mismatch in the keys (model and optimizer) in the two state_dicts""" # check presence of 'model' in the input state_dict if _utils.state_dict_model_key() in state_dict: - _check_model_key_mismatch(current_state_dict, state_dict) + _check_model_key_mismatch(current_state_dict, state_dict, allow_unexpected) else: warnings.warn("Missing key: model in state_dict", UserWarning) # check presence of 'optimizer' in the input state_dict if _utils.state_dict_optimizer_key() in state_dict: - _check_optimizer_key_mismatch(current_state_dict, state_dict) + _check_optimizer_key_mismatch(current_state_dict, state_dict, allow_unexpected) else: warnings.warn("Missing key: optimizer in state_dict", UserWarning) @@ -1228,7 +1249,10 @@ def _check_key_mismatch(current_state_dict, state_dict): if self._training_session: current_state_dict = self.state_dict() if strict: - _check_key_mismatch(current_state_dict, state_dict) + # for Zero enabled, the current trainer might not have the complete state, and we must allow + # extra keys to be present in the state dict + allow_unexpected = True if self.options.distributed.deepspeed_zero_optimization.stage > 0 else False + _check_key_mismatch(current_state_dict, state_dict, allow_unexpected) # load the model states from the input state dictionary into the onnx graph self._load_model_states(state_dict, strict) @@ -1301,8 +1325,10 @@ def save_checkpoint(self, path, 
user_dict={}, include_optimizer_states=True): def _aggregation_required(self, loaded_trainer_options): """Checks if aggregation is required for the loading the state_dict into the ORTTrainer""" - # To load states in the backend, aggregation is required for every ZeRO checkpoint - return loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 + # To load states in the backend, aggregation is required for every ZeRO + # or Megatron checkpoint + return loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 or \ + loaded_trainer_options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()] > 1 def load_checkpoint(self, *paths, strict=True): """Loads the saved checkpoint state dictionary into the ORTTrainer diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py index 422420b4bbb13..04f7e62ca96bb 100644 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ b/orttraining/orttraining/python/training/orttrainer_options.py @@ -290,7 +290,7 @@ class ORTTrainerOptions(object): distributed (dict): distributed training options. distributed.world_rank (int, default is 0): - rank ID used for data parallelism + rank ID used for data/horizontal parallelism distributed.world_size (int, default is 1): number of ranks participating in parallelism distributed.data_parallel_size (int, default is 1): diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 2e0c08eee7cdd..d30e6aed6e374 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -167,7 +167,9 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) { std::unordered_map updated_weight_names; std::unordered_set weights_to_train; std::unordered_map weight_partition_info; - graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2, updated_weight_names, weights_to_train, weight_partition_info), TransformerLevel::Level1); + training::TrainingSession::OptimizerState init_optim_state; + IExecutionProvider* e = TestCPUExecutionProvider(); + graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2, updated_weight_names, weights_to_train, weight_partition_info, init_optim_state, *e), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -238,7 +240,9 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank1) { std::unordered_map updated_weight_names; std::unordered_set weights_to_train; std::unordered_map weight_partition_info; - graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2, updated_weight_names, weights_to_train, weight_partition_info), TransformerLevel::Level1); + training::TrainingSession::OptimizerState init_optim_state; + IExecutionProvider* e = TestCPUExecutionProvider(); + graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2, updated_weight_names, weights_to_train, weight_partition_info, init_optim_state, *e), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -308,7 +312,9 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) { std::unordered_map updated_weight_names; std::unordered_set weights_to_train; std::unordered_map weight_partition_info; - 
graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2, updated_weight_names, weights_to_train, weight_partition_info), TransformerLevel::Level1); + training::TrainingSession::OptimizerState init_optim_state; + IExecutionProvider* e = TestCPUExecutionProvider(); + graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2, updated_weight_names, weights_to_train, weight_partition_info, init_optim_state, *e), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -376,7 +382,9 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) { std::unordered_map updated_weight_names; std::unordered_set weights_to_train; std::unordered_map weight_partition_info; - graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2, updated_weight_names, weights_to_train, weight_partition_info), TransformerLevel::Level1); + training::TrainingSession::OptimizerState init_optim_state; + IExecutionProvider* e = TestCPUExecutionProvider(); + graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2, updated_weight_names, weights_to_train, weight_partition_info, init_optim_state, *e), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -479,7 +487,9 @@ static void RunPartitionCorrectnessTest(std::string model_path, std::unordered_map updated_weight_names; std::unordered_set weights_to_train; std::unordered_map weight_partition_info; - graph_transformation_mgr.Register(onnxruntime::make_unique(i, total_rank, updated_weight_names, weights_to_train, weight_partition_info), TransformerLevel::Level1); + training::TrainingSession::OptimizerState init_optim_state; + IExecutionProvider* e = TestCPUExecutionProvider(); + graph_transformation_mgr.Register(onnxruntime::make_unique(i, total_rank, updated_weight_names, weights_to_train, weight_partition_info, init_optim_state, *e), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger); ORT_ENFORCE(ret.IsOK()); graphs.push_back(&graph); diff --git a/orttraining/orttraining/test/python/_test_commons.py b/orttraining/orttraining/test/python/_test_commons.py index f666b214d5683..325047f70ff41 100644 --- a/orttraining/orttraining/test/python/_test_commons.py +++ b/orttraining/orttraining/test/python/_test_commons.py @@ -8,6 +8,7 @@ import torch import onnx +import onnxruntime from onnxruntime.training import optim, _utils def _single_run(execution_file, scenario, checkopint_dir = None): @@ -20,7 +21,7 @@ def _single_run(execution_file, scenario, checkopint_dir = None): def _distributed_run(execution_file, scenario, checkopint_dir = None): ngpus = torch.cuda.device_count() - cmd = ['mpirun', '-n', str(ngpus), '-x', 'NCCL_DEBUG=INFO', sys.executable, execution_file] + cmd = ['mpirun', '-n', str(ngpus), '-x', 'NCCL_DEBUG=INFO', '--tag-output', sys.executable, execution_file] if scenario: cmd += ['--scenario', scenario] if checkopint_dir: @@ -167,6 +168,39 @@ def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False train_data, val_data, test_data = utils.prepare_data(device, 20, 20) return model, model_desc, my_loss, utils.get_batch, train_data, val_data, test_data +def generate_random_input_from_bart_model_desc(desc, seed=1, device = "cuda:0"): + '''Generates a sample input for the BART model using the model desc''' + + torch.manual_seed(seed) + onnxruntime.set_seed(seed) + 
dtype = torch.int64 + vocab_size = 30528 + sample_input = [] + for index, input in enumerate(desc['inputs']): + size = [] + for s in input[1]: + if isinstance(s, (int)): + size.append(s) + else: + size.append(1) + sample_input.append(torch.randint(0, vocab_size, tuple(size), dtype=dtype).to(device)) + return sample_input + +def _load_bart_model(): + bart_onnx_model_path = os.path.join('testdata', "bart_tiny.onnx") + model = onnx.load(bart_onnx_model_path) + batch = 2 + seq_len = 1024 + model_desc = { + 'inputs': [ + ('src_tokens', [batch, seq_len],), + ('prev_output_tokens', [batch, seq_len],), + ('target', [batch*seq_len],)], + 'outputs': [ + ('loss', [], True)]} + + return model, model_desc + def assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint, reshape_states=False): """Assert that the two ORTTrainer (hierarchical) state dictionaries are very close for all states""" diff --git a/orttraining/orttraining/test/python/checkpoint/_test_helpers.py b/orttraining/orttraining/test/python/checkpoint/_test_helpers.py index 5eede2b3810e3..734ab53e48349 100644 --- a/orttraining/orttraining/test/python/checkpoint/_test_helpers.py +++ b/orttraining/orttraining/test/python/checkpoint/_test_helpers.py @@ -2,15 +2,19 @@ import pickle from itertools import islice import glob +import copy import torch import torch.distributed as dist +from onnx import numpy_helper from onnxruntime import set_seed -from onnxruntime.training import amp, checkpoint, optim, orttrainer +from onnxruntime.training import amp, checkpoint, _checkpoint_storage, optim, orttrainer from onnxruntime.capi._pybind_state import set_cuda_device_id, get_mpi_context_world_rank, get_mpi_context_world_size -from _test_commons import generate_dummy_optim_state, _load_pytorch_transformer_model, assert_all_states_close_ort, assert_all_states_close_pytorch +from _test_commons import generate_random_input_from_bart_model_desc, generate_dummy_optim_state, \ + _load_pytorch_transformer_model, _load_bart_model, \ + assert_all_states_close_ort, assert_all_states_close_pytorch from numpy.testing import assert_allclose, assert_array_equal @@ -40,6 +44,9 @@ def _save(trainer, checkpoint_dir, state_dict_key_name, world_rank=None): with open(os.path.join(checkpoint_dir, state_dict_key_name+'.pkl'), "wb") as f: pickle.dump({state_dict_key_name : state_dict}, f) +def save_ort_ckpt(state_dict, filepath): + _checkpoint_storage.save(state_dict, filepath) + def _chunkify(sequence, num_chunks): """Breaks down a given sequence into num_chunks chunks""" quo, rem = divmod(len(sequence), num_chunks) @@ -56,6 +63,11 @@ def _setup_test_infra(world_rank, world_size): dist.init_process_group(backend='nccl', world_size=world_size, rank=world_rank) +def _is_model_parallel_run(trainer_options): + zero = trainer_options.distributed.deepspeed_zero_optimization.stage > 0 + megatron = trainer_options.distributed.horizontal_parallel_size > 1 + return zero or megatron + def distributed_setup(func): """Decorator function for distributed tests. 
@@ -77,7 +89,7 @@ def setup(checkpoint_dir): return setup -def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, use_lamb=True): +def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, use_lamb=True, seed=1, learning_rate=0.1): """Instantiate and load checkpoint into trainer - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple transformer model @@ -85,12 +97,10 @@ def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, - Runs eval_step on the trainer so the trainer onnx graph is initialized - Returns the trainer state_dict and the pytorch model """ - seed = 1 torch.manual_seed(seed) set_seed(seed) - # PyTorch transformer model setup - learning_rate = 0.1 + # PyTorch transformer model setup optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate) model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts)) @@ -108,12 +118,43 @@ def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, return trainer.state_dict(), model -def create_initialized_orttrainer(device, trainer_opts, use_lamb=True): - seed = 1 +def create_orttrainer_and_load_checkpoint_bart(device, trainer_opts, checkpoint_dir, use_lamb=True, seed=1, learning_rate=0.1): + """Instantiate and load checkpoint into trainer + + - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model + - Loads the checkpoint from directory checkpoint_dir into the trainer + - Runs eval_step on the trainer so the trainer onnx graph is initialized + - Returns the trainer state_dict, the expected state dict if present, and the onnx model + """ + torch.manual_seed(seed) + set_seed(seed) + + # model setup + optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate) + model, model_desc = _load_bart_model() + trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=orttrainer.ORTTrainerOptions(trainer_opts)) + + # load checkpoint into trainer + checkpoint_file_name = 'checkpoint*.ortcp' + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name)) + trainer.load_checkpoint(*checkpoint_files) + + # run an eval step to initialize the graph + src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed = seed) + trainer.eval_step(src_tokens, prev_output_tokens, target) + + expected_state_dict = None + fname = os.path.join(checkpoint_dir, 'expected_state_dict.pkl') + if os.path.isfile(fname): + with open(fname, "rb") as f: + expected_state_dict = pickle.load(f) + + return trainer.state_dict(), expected_state_dict, model + +def create_initialized_orttrainer(device, trainer_opts, use_lamb=True, seed=1, learning_rate=1e-10): torch.manual_seed(seed) set_seed(seed) - learning_rate = 1e-10 optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate) model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts)) @@ -243,16 +284,14 @@ def aggregate_states(checkpoint_dir, filename_prefix='state_dict', state_dict_ke return aggregated_states -def 
create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir, state_dict_key_name='state_dict', use_lamb=True): - learning_rate = 0.1 - seed = 1 - +def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir, state_dict_key_name='state_dict', use_lamb=True, seed=1, learning_rate=0.1): torch.manual_seed(seed) set_seed(seed) + ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts) optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate) model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts)) + trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=ort_trainer_opts) if 'distributed' in trainer_opts: train_data = next(islice(_chunkify(train_data, trainer_opts['distributed']['world_size']), trainer_opts['distributed']['world_rank'], None)) @@ -262,15 +301,51 @@ def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir, # save current model parameters as a checkpoint if checkpoint_dir: - if 'distributed' in trainer_opts and 'deepspeed_zero_optimization' in trainer_opts['distributed']: - _save(trainer, checkpoint_dir, state_dict_key_name, world_rank=trainer_opts['distributed']['world_rank']) + if _is_model_parallel_run(ort_trainer_opts): + _save(trainer, checkpoint_dir, state_dict_key_name, world_rank=ort_trainer_opts.distributed.world_rank) else: _save(trainer, checkpoint_dir, state_dict_key_name) -def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True): - learning_rate = 0.1 - seed = 1 +def create_orttrainer_and_save_checkpoint_bart(device, trainer_opts, checkpoint_dir, state_dict_key_name='state_dict', use_lamb=True, seed=1, learning_rate=0.1): + """Instantiate trainer and save checkpoint for BART. 
+ + - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model + - Loads a dummy optimizer state into the trainer + - Runs eval_step on the trainer so the trainer onnx graph is initialized + - Saves the trainer state as a checkpoint; for model-parallel runs, rank 0 also saves the complete initial model and optimizer states as the expected state dictionary + """ + torch.manual_seed(seed) + set_seed(seed) + + ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts) + optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate) + model, model_desc = _load_bart_model() + trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=ort_trainer_opts) + + # load dummy optimizer state as we are not going to run real training + dummy_init_state = generate_dummy_optim_state(model, optim_config) + init_state = copy.deepcopy(dummy_init_state) + trainer.load_state_dict(dummy_init_state) + + # run an eval step to initialize the graph + src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed = seed) + trainer.eval_step(src_tokens, prev_output_tokens, target) + + # save current model parameters as a checkpoint + if checkpoint_dir: + if _is_model_parallel_run(ort_trainer_opts): + _save(trainer, checkpoint_dir, state_dict_key_name, world_rank=ort_trainer_opts.distributed.world_rank) + # save the initial complete model and optimizer states + if ort_trainer_opts.distributed.world_rank == 0: + init_state['model'] = {'full_precision': dict()} + for initializer in model.graph.initializer: + init_state['model']['full_precision'][initializer.name] = numpy_helper.to_array(initializer) + with open(os.path.join(checkpoint_dir, 'expected_state_dict.pkl'), "wb") as f: + pickle.dump(init_state, f) + else: + _save(trainer, checkpoint_dir, state_dict_key_name) +def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True, seed=1, learning_rate=0.1): torch.manual_seed(seed) set_seed(seed) diff --git a/orttraining/orttraining/test/python/checkpoint/orttraining_test_checkpoint_aggregation.py b/orttraining/orttraining/test/python/checkpoint/orttraining_test_checkpoint_aggregation.py index b7864edb851b0..dee7da9c86890 100644 --- a/orttraining/orttraining/test/python/checkpoint/orttraining_test_checkpoint_aggregation.py +++ b/orttraining/orttraining/test/python/checkpoint/orttraining_test_checkpoint_aggregation.py @@ -21,7 +21,7 @@ import onnxruntime from onnxruntime.training import checkpoint -from _test_helpers import distributed_setup, create_orttrainer_and_load_checkpoint, aggregate_states +from _test_helpers import distributed_setup, create_orttrainer_and_load_checkpoint, create_orttrainer_and_load_checkpoint_bart, aggregate_states from _test_commons import assert_all_states_close_ort @@ -40,6 +40,17 @@ def test_zero_aggregation(checkpoint_dir, loaded_state_dict, is_mixedprecision): assert_all_states_close_ort(aggregate_state_dict_from_test, aggregate_state_dict_from_checkpoint, reshape_states=True) +def test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision): + # get aggregated state dict independently + aggregate_state_dict_from_checkpoint = \ + checkpoint.aggregate_checkpoints(glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")), pytorch_format=False) + + # verify loaded state and aggregated states match: + assert_all_states_close_ort(loaded_state_dict, aggregate_state_dict_from_checkpoint) + + # compare with expected state dict + assert_all_states_close_ort(expected_state_dict, 
loaded_state_dict) + def test_aggregation_from_distributed_zero_full_precision_adam(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero/full_precision/adam/'): opts = {'device': {'id': device}, 'debug': {'deterministic_compute': True}} @@ -88,12 +99,115 @@ def test_aggregation_from_distributed_zero_mixed_precision_lamb(device='cuda', c test_zero_aggregation(checkpoint_dir, loaded_state_dict, is_mixedprecision=True) +def test_aggregation_from_distributed_megatron_full_precision_adam(device='cuda', checkpoint_dir='checkpoint_dir/distributed_megatron/full_precision/adam/'): + opts = {'device': {'id': device}, + 'debug': {'deterministic_compute': True}} + + # extract state dictionaries to compare + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=False) + test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=False) + + +def test_aggregation_from_distributed_megatron_mixed_precision_adam(device='cuda', checkpoint_dir='checkpoint_dir/distributed_megatron/mixed_precision/adam/'): + opts = { + 'device': {'id': device}, + 'mixed_precision': + { + 'enabled': True + }, + 'debug': {'deterministic_compute': True} + } + + # extract state dictionaries to compare + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=False) + test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=True) + + +def test_aggregation_from_distributed_megatron_full_precision_lamb(device='cuda', checkpoint_dir='checkpoint_dir/distributed_megatron/full_precision/lamb/'): + opts = {'device': {'id': device}, + 'debug': {'deterministic_compute': True}} + + # extract state dictionaries to compare + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=True) + test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=False) + + +def test_aggregation_from_distributed_megatron_mixed_precision_lamb(device='cuda', checkpoint_dir='checkpoint_dir/distributed_megatron/mixed_precision/lamb/'): + opts = { + 'device': {'id': device}, + 'mixed_precision': + { + 'enabled': True + }, + 'debug': {'deterministic_compute': True} + } + + # extract state dictionaries to compare + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=True) + test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=True) + +def test_aggregation_from_distributed_zero_megatron_full_precision_adam(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero_megatron/full_precision/adam/'): + opts = {'device': {'id': device}, + 'debug': {'deterministic_compute': True}} + + # extract state dictionaries to compare + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=False) + test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=False) + + +def test_aggregation_from_distributed_zero_megatron_mixed_precision_adam(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero_megatron/mixed_precision/adam/'): + opts = { + 'device': {'id': device}, + 'mixed_precision': + { + 'enabled': True + }, + 'debug': {'deterministic_compute': True} + } + + # extract state 
dictionaries to compare + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=False) + test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=True) + + +def test_aggregation_from_distributed_zero_megatron_full_precision_lamb(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero_megatron/full_precision/lamb/'): + opts = {'device': {'id': device}, + 'debug': {'deterministic_compute': True}} + + # extract state dictionaries to compare + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=True) + test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=False) + + +def test_aggregation_from_distributed_zero_megatron_mixed_precision_lamb(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero_megatron/mixed_precision/lamb/'): + opts = { + 'device': {'id': device}, + 'mixed_precision': + { + 'enabled': True + }, + 'debug': {'deterministic_compute': True} + } + + # extract state dictionaries to compare + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=True) + test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=True) + + function_map = { # all config to single node config 'test_aggregation_from_distributed_zero_full_precision_adam': test_aggregation_from_distributed_zero_full_precision_adam, 'test_aggregation_from_distributed_zero_mixed_precision_adam': test_aggregation_from_distributed_zero_mixed_precision_adam, 'test_aggregation_from_distributed_zero_mixed_precision_lamb': test_aggregation_from_distributed_zero_mixed_precision_lamb, - 'test_aggregation_from_distributed_zero_full_precision_lamb': test_aggregation_from_distributed_zero_full_precision_lamb + 'test_aggregation_from_distributed_zero_full_precision_lamb': test_aggregation_from_distributed_zero_full_precision_lamb, + 'test_aggregation_from_distributed_megatron_full_precision_adam': test_aggregation_from_distributed_megatron_full_precision_adam, + 'test_aggregation_from_distributed_megatron_mixed_precision_adam': test_aggregation_from_distributed_megatron_mixed_precision_adam, + 'test_aggregation_from_distributed_megatron_mixed_precision_lamb': test_aggregation_from_distributed_megatron_mixed_precision_lamb, + 'test_aggregation_from_distributed_megatron_full_precision_lamb': test_aggregation_from_distributed_megatron_full_precision_lamb, + 'test_aggregation_from_distributed_zero_megatron_full_precision_adam': test_aggregation_from_distributed_zero_megatron_full_precision_adam, + 'test_aggregation_from_distributed_zero_megatron_mixed_precision_adam': test_aggregation_from_distributed_zero_megatron_mixed_precision_adam, + 'test_aggregation_from_distributed_zero_megatron_mixed_precision_lamb': test_aggregation_from_distributed_zero_megatron_mixed_precision_lamb, + 'test_aggregation_from_distributed_zero_megatron_full_precision_lamb': test_aggregation_from_distributed_zero_megatron_full_precision_lamb } parser = argparse.ArgumentParser(description='Test aggregation of states for Zero-1') parser.add_argument('--scenario', choices=function_map.keys(), help='training scenario to test saved and loaded states', required=True) diff --git a/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_checkpoint.py 
b/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_checkpoint.py index acc2c371ac9a7..ffcc6eea69642 100644 --- a/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_checkpoint.py +++ b/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_checkpoint.py @@ -19,10 +19,10 @@ import onnxruntime from onnxruntime.training import checkpoint -from _test_helpers import distributed_setup, create_orttrainer_and_load_checkpoint, aggregate_states, assert_all_states_close +from _test_helpers import distributed_setup, create_orttrainer_and_load_checkpoint, create_orttrainer_and_load_checkpoint_bart, aggregate_states, assert_all_states_close, save_ort_ckpt from _test_commons import assert_all_states_close_ort, assert_all_states_close_pytorch -def test_load_from_single_node_full_precision_into_single_node_full_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/single_node/full_precision/'): +def test_load_from_single_node_full_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): opts = {'device' : {'id' : device}, 'debug' : {'deterministic_compute': True}} @@ -32,7 +32,7 @@ def test_load_from_single_node_full_precision_into_single_node_full_precision(de # compare all states assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) -def test_load_from_single_node_mixed_precision_into_single_node_full_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/single_node/mixed_precision/'): +def test_load_from_single_node_mixed_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): opts = {'device' : {'id' : device}, 'debug' : {'deterministic_compute': True}} @@ -42,7 +42,7 @@ def test_load_from_single_node_mixed_precision_into_single_node_full_precision(d # compare all states assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) -def test_load_from_single_node_mixed_precision_into_single_node_mixed_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/single_node/mixed_precision/'): +def test_load_from_single_node_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -58,7 +58,7 @@ def test_load_from_single_node_mixed_precision_into_single_node_mixed_precision( # compare all states assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) -def test_load_from_single_node_full_precision_into_single_node_mixed_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/single_node/full_precision/'): +def test_load_from_single_node_full_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -74,7 +74,7 @@ def test_load_from_single_node_full_precision_into_single_node_mixed_precision(d # compare all states assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) -def test_load_from_data_parallelism_full_precision_into_single_node_full_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/data_parallelism/full_precision/'): +def test_load_from_data_parallelism_full_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): opts = {'device' : {'id' : device}, 'debug' : {'deterministic_compute': True}} @@ -84,7 +84,7 @@ def test_load_from_data_parallelism_full_precision_into_single_node_full_precisi # compare all states assert_all_states_close(checkpoint_dir, 'state_dict', 
state_dict_post_checkpoint, model) -def test_load_from_data_parallelism_mixed_precision_into_single_node_full_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/data_parallelism/mixed_precision/'): +def test_load_from_data_parallelism_mixed_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): opts = {'device' : {'id' : device}, 'debug' : {'deterministic_compute': True}} @@ -94,7 +94,7 @@ def test_load_from_data_parallelism_mixed_precision_into_single_node_full_precis # compare all states assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) -def test_load_from_data_parallelism_mixed_precision_into_single_node_mixed_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/data_parallelism/mixed_precision/'): +def test_load_from_data_parallelism_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -110,7 +110,7 @@ def test_load_from_data_parallelism_mixed_precision_into_single_node_mixed_preci # compare all states assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) -def test_load_from_data_parallelism_full_precision_into_single_node_mixed_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/data_parallelism/full_precision/'): +def test_load_from_data_parallelism_full_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -126,7 +126,7 @@ def test_load_from_data_parallelism_full_precision_into_single_node_mixed_precis # compare all states assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) -def test_load_from_distributed_zero_full_precision_into_single_node_full_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/'): +def test_load_from_distributed_zero_full_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): opts = {'device' : {'id' : device}, 'debug' : {'deterministic_compute': True}} @@ -144,7 +144,7 @@ def test_load_from_distributed_zero_full_precision_into_single_node_full_precisi agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) -def test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/lamb'): +def test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): opts = {'device' : {'id' : device}, 'debug' : {'deterministic_compute': True}} @@ -162,7 +162,7 @@ def test_load_from_distributed_zero_mixed_precision_into_single_node_full_precis agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) -def test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/lamb'): +def test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -186,7 +186,7 @@ def test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_preci agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, 
model) -def test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/lamb/'): +def test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -210,8 +210,106 @@ def test_load_from_distributed_zero_full_precision_into_single_node_mixed_precis agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) +def test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir): + # compare the expected dictionary with the aggregated state dictionary from the ORTTrainer + assert_all_states_close_ort(expected_state_dict, state_dict_post_checkpoint) + + # TODO: aggregate checkpoints previously saved and load it into the pytorch model for comparison, + # need to add support to add the bart pytorch model to unit tests instead of current onnx model + # checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) + # agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) + # assert_all_states_close_pytorch(agg_state_dict, model) + +def test_load_from_distributed_megatron_full_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): + opts = {'device' : {'id' : device}, + 'debug' : {'deterministic_compute': True}} + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +def test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): + opts = {'device' : {'id' : device}, + 'debug' : {'deterministic_compute': True}} + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +def test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +def test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +def 
test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): + opts = {'device' : {'id' : device}, + 'debug' : {'deterministic_compute': True}} + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +def test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): + opts = {'device' : {'id' : device}, + 'debug' : {'deterministic_compute': True}} + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +def test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +def test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + @distributed_setup -def test_load_from_single_node_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/single_node/full_precision/'): +def test_load_from_single_node_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -230,7 +328,7 @@ def test_load_from_single_node_full_precision_into_data_parallelism_full_precisi assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) @distributed_setup -def test_load_from_single_node_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/single_node/mixed_precision/'): +def test_load_from_single_node_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -249,7 +347,7 @@ def test_load_from_single_node_mixed_precision_into_data_parallelism_full_precis assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) @distributed_setup -def test_load_from_single_node_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/single_node/mixed_precision/'): +def 
test_load_from_single_node_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -272,7 +370,7 @@ def test_load_from_single_node_mixed_precision_into_data_parallelism_mixed_preci assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) @distributed_setup -def test_load_from_single_node_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/single_node/full_precision/'): +def test_load_from_single_node_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -295,7 +393,7 @@ def test_load_from_single_node_full_precision_into_data_parallelism_mixed_precis assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) @distributed_setup -def test_load_from_data_parallelism_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/data_parallelism/full_precision/'): +def test_load_from_data_parallelism_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -314,7 +412,7 @@ def test_load_from_data_parallelism_full_precision_into_data_parallelism_full_pr assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/data_parallelism/mixed_precision/'): +def test_load_from_data_parallelism_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -333,7 +431,7 @@ def test_load_from_data_parallelism_mixed_precision_into_data_parallelism_full_p assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/data_parallelism/mixed_precision/'): +def test_load_from_data_parallelism_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -356,7 +454,7 @@ def test_load_from_data_parallelism_mixed_precision_into_data_parallelism_mixed_ assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) @distributed_setup -def test_load_from_data_parallelism_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/data_parallelism/full_precision/'): +def test_load_from_data_parallelism_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -379,7 +477,7 @@ def test_load_from_data_parallelism_full_precision_into_data_parallelism_mixed_p assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) @distributed_setup -def test_load_from_distributed_zero_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/lamb/'): +def 
test_load_from_distributed_zero_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -406,7 +504,7 @@ def test_load_from_distributed_zero_full_precision_into_data_parallelism_full_pr assert_all_states_close_pytorch(agg_state_dict, model) @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/lamb'): +def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -433,7 +531,7 @@ def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_p assert_all_states_close_pytorch(agg_state_dict, model) @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/lamb'): +def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -464,7 +562,7 @@ def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_ assert_all_states_close_pytorch(agg_state_dict, model) @distributed_setup -def test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/lamb/'): +def test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -494,8 +592,161 @@ def test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_p agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) + +@distributed_setup +def test_load_from_distributed_megatron_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +@distributed_setup +def test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +@distributed_setup +def 
test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +@distributed_setup +def test_load_from_distributed_megatron_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +@distributed_setup +def test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +@distributed_setup +def test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +@distributed_setup +def test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + +@distributed_setup +def 
test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + @distributed_setup -def test_load_from_single_node_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/single_node/full_precision/'): +def test_load_from_single_node_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -519,7 +770,7 @@ def test_load_from_single_node_full_precision_into_distributed_zero_full_precisi state = pickle.load(f) state_dict_pre_checkpoint = state['state_dict'] - # To compare state dictioanry from a single node trainer to the state dictioanry from a zero run: + # To compare state dictionary from a single node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. @@ -539,7 +790,7 @@ def test_load_from_single_node_full_precision_into_distributed_zero_full_precisi os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) @distributed_setup -def test_load_from_single_node_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/single_node/mixed_precision/'): +def test_load_from_single_node_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -563,7 +814,7 @@ def test_load_from_single_node_mixed_precision_into_distributed_zero_full_precis state = pickle.load(f) state_dict_pre_checkpoint = state['state_dict'] - # To compare state dictioanry from a single node trainer to the state dictioanry from a zero run: + # To compare state dictionary from a single node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. 
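The comment block above describes the manual aggregation path these ZeRO comparison tests follow: each rank pickles its own hierarchical state dictionary, and rank 0 reloads and merges them before comparing against the single-node baseline (the tests import an aggregate_states helper from _test_helpers for this purpose). A minimal sketch of that merge, assuming each per-rank pickle stores the trainer state under a 'state_dict' key as the single-node pickles above do; merge_rank_model_states is an illustrative name, not a function in this patch:

import glob
import os
import pickle

def merge_rank_model_states(checkpoint_dir, prefix='distributed_state_'):
    # load every per-rank pickle (distributed_state_<rank>.pkl) written by the distributed test
    aggregated = {'model': {'full_precision': {}}}
    for path in sorted(glob.glob(os.path.join(checkpoint_dir, prefix + '*.pkl'))):
        with open(path, 'rb') as f:
            rank_state = pickle.load(f)['state_dict']
        # each rank only holds its own partition of the weights, so a plain dict
        # update is enough to recover the full set of full-precision model states
        aggregated['model']['full_precision'].update(
            rank_state.get('model', {}).get('full_precision', {}))
    return aggregated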
@@ -583,7 +834,7 @@ def test_load_from_single_node_mixed_precision_into_distributed_zero_full_precis os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) @distributed_setup -def test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/single_node/mixed_precision/'): +def test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -611,7 +862,7 @@ def test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_preci state = pickle.load(f) state_dict_pre_checkpoint = state['state_dict'] - # To compare state dictioanry from a single node trainer to the state dictioanry from a zero run: + # To compare state dictionary from a single node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. @@ -631,7 +882,7 @@ def test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_preci os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) @distributed_setup -def test_load_from_single_node_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/single_node/full_precision/'): +def test_load_from_single_node_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -659,7 +910,7 @@ def test_load_from_single_node_full_precision_into_distributed_zero_mixed_precis state = pickle.load(f) state_dict_pre_checkpoint = state['state_dict'] - # To compare state dictioanry from a single node trainer to the state dictioanry from a zero run: + # To compare state dictionary from a single node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. 
@@ -679,7 +930,7 @@ def test_load_from_single_node_full_precision_into_distributed_zero_mixed_precis os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) @distributed_setup -def test_load_from_data_parallelism_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/data_parallelism/full_precision/'): +def test_load_from_data_parallelism_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -703,7 +954,7 @@ def test_load_from_data_parallelism_full_precision_into_distributed_zero_full_pr state = pickle.load(f) state_dict_pre_checkpoint = state['state_dict'] - # To compare state dictioanry from a data parallel node trainer to the state dictioanry from a zero run: + # To compare state dictionary from a data parallel node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. @@ -723,7 +974,7 @@ def test_load_from_data_parallelism_full_precision_into_distributed_zero_full_pr os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/data_parallelism/mixed_precision/'): +def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -747,7 +998,7 @@ def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_full_p state = pickle.load(f) state_dict_pre_checkpoint = state['state_dict'] - # To compare state dictioanry from a data parallel node trainer to the state dictioanry from a zero run: + # To compare state dictionary from a data parallel node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. 
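Several tests in this file also aggregate the saved checkpoint*.ortcp files directly through the checkpoint API instead of going through pickled per-rank states. A brief sketch of the two output formats used in this patch (the directory path below is only a placeholder):

import glob
import os

from onnxruntime.training import checkpoint

checkpoint_files = glob.glob(os.path.join('checkpoint_dir/distributed_zero/full_precision/lamb', 'checkpoint*.ortcp'))

# hierarchical ORT-format state (model, optimizer and trainer options), compared with assert_all_states_close_ort
ort_state = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False)

# flat, model-only state compatible with a PyTorch model, compared with assert_all_states_close_pytorch
torch_state = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True)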
@@ -767,7 +1018,7 @@ def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_full_p os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/data_parallelism/mixed_precision/'): +def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -795,7 +1046,7 @@ def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_mixed_ state = pickle.load(f) state_dict_pre_checkpoint = state['state_dict'] - # To compare state dictioanry from a data parallel node trainer to the state dictioanry from a zero run: + # To compare state dictionary from a data parallel node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. @@ -815,7 +1066,7 @@ def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_mixed_ os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) @distributed_setup -def test_load_from_data_parallelism_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/data_parallelism/full_precision/'): +def test_load_from_data_parallelism_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -843,7 +1094,7 @@ def test_load_from_data_parallelism_full_precision_into_distributed_zero_mixed_p state = pickle.load(f) state_dict_pre_checkpoint = state['state_dict'] - # To compare state dictioanry from a data parallel node trainer to the state dictioanry from a zero run: + # To compare state dictionary from a data parallel node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. 
@@ -863,7 +1114,7 @@ def test_load_from_data_parallelism_full_precision_into_distributed_zero_mixed_p os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) @distributed_setup -def test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/lamb/'): +def test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -891,7 +1142,7 @@ def test_load_from_distributed_zero_full_precision_into_distributed_zero_full_pr assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/lamb'): +def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -910,7 +1161,7 @@ def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_p # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) - # To compare state dictioanry between two distributed zero node trainers (with different mixed precision parameter): + # To compare state dictionary between two distributed zero node trainers (with different mixed precision parameter): # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary (distributed_state_world_rank.pkl) and aggregate all of them into a single state dictionary. # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. 
@@ -939,7 +1190,7 @@ def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_p os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/lamb'): +def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -971,7 +1222,7 @@ def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_ assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) @distributed_setup -def test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/lamb/'): +def test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -994,7 +1245,7 @@ def test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_p # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) - # To compare state dictioanry between two distributed zero node trainers (with different mixed precision parameter): + # To compare state dictionary between two distributed zero node trainers (with different mixed precision parameter): # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary (distributed_state_world_rank.pkl) and aggregate all of them into a single state dictionary. # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. @@ -1022,6 +1273,2043 @@ def test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_p dist.barrier() os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) +@distributed_setup +def test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) + # - On rank 0, manually load each state dictionary (distributed_state_world_rank.pkl) and aggregate all of them into a single state dictionary. + # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + pickle.dump(state_dict_post_checkpoint, f) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + +@distributed_setup +def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) + # - On rank 0, manually load each state dictionary (distributed_state_world_rank.pkl) and aggregate all of them into a single state dictionary. + # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. 
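The reshape_states=True argument in these assertions reflects that the two runs shard model and optimizer state differently, so an aggregated tensor can come back with a different (for example, flattened) shape for the same parameter. A minimal illustration of a shape-insensitive comparison of that kind; the real assert_all_states_close_ort helper is assumed to handle this along broadly similar lines.

import numpy as np


def allclose_ignoring_shape(expected, actual, rtol=1e-5, atol=1e-8):
    # Compare two arrays that hold the same values but possibly different shapes,
    # e.g. a (4, 8) weight versus its flattened (32,) aggregate.
    expected = np.asarray(expected)
    actual = np.asarray(actual)
    assert expected.size == actual.size, 'parameter size mismatch'
    return np.allclose(expected.reshape(-1), actual.reshape(-1), rtol=rtol, atol=atol)


# Example: the same weight seen in its original 2-D layout and as a flat aggregate.
w = np.arange(32, dtype=np.float32).reshape(4, 8)
assert allclose_ignoring_shape(w, w.reshape(-1))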
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + pickle.dump(state_dict_post_checkpoint, f) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + +@distributed_setup +def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) + # - On rank 0, manually load each state dictionary (distributed_state_world_rank.pkl) and aggregate all of them into a single state dictionary. + # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. 
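Note the state_dict_key_name=None argument to aggregate_states: the tests in this block pickle the state dictionary directly, whereas the earlier tests in this file pickle a wrapper and read state['state_dict'] back. A hypothetical sketch of the loading step that argument is assumed to control (not the real helper):

import pickle


def load_rank_state(path, state_dict_key_name=None):
    # Assumed behaviour: read one per-rank pickle and optionally unwrap it by key
    # before handing it to the aggregation logic.
    with open(path, 'rb') as f:
        state = pickle.load(f)
    if state_dict_key_name is None:
        return state
    return state[state_dict_key_name]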
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + pickle.dump(state_dict_post_checkpoint, f) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + +@distributed_setup +def test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) + # - On rank 0, manually load each state dictionary (distributed_state_world_rank.pkl) and aggregate all of them into a single state dictionary. + # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + pickle.dump(state_dict_post_checkpoint, f) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + +@distributed_setup +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) + # - On rank 0, manually load each state dictionary (distributed_state_world_rank.pkl) and aggregate all of them into a single state dictionary. + # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + pickle.dump(state_dict_post_checkpoint, f) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + +@distributed_setup +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) + # - On rank 0, manually load each state dictionary (distributed_state_world_rank.pkl) and aggregate all of them into a single state dictionary. + # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + pickle.dump(state_dict_post_checkpoint, f) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + +@distributed_setup +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) + # - On rank 0, manually load each state dictionary (distributed_state_world_rank.pkl) and aggregate all of them into a single state dictionary. + # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + pickle.dump(state_dict_post_checkpoint, f) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + +@distributed_setup +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) + # - On rank 0, manually load each state dictionary (distributed_state_world_rank.pkl) and aggregate all of them into a single state dictionary. + # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + pickle.dump(state_dict_post_checkpoint, f) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + + +########################################################################################################################################### +# LOAD TO MEGATRON +########################################################################################################################################### + +@distributed_setup +def test_load_from_single_node_full_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a single node trainer to the state dictionary from a megatron run: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected single node state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a single node trainer to the state dictionary from a megatron run: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. 
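In this section the per-rank states are written as .ort.pt files with save_ort_ckpt and merged with checkpoint.aggregate_checkpoints. A minimal sketch under two assumptions: that save_ort_ckpt is essentially torch.save of the ORT state dictionary, and that pytorch_format=False asks aggregate_checkpoints for the full ORT-format dictionary (optimizer state included) rather than a plain PyTorch state_dict.

import glob
import os

import torch
from onnxruntime.training import checkpoint


def save_ort_ckpt_sketch(state_dict, filepath):
    # Assumed behaviour of the save_ort_ckpt helper: persist the ORT state
    # dictionary so torch.load / aggregate_checkpoints can read it back.
    torch.save(state_dict, filepath)


def aggregate_rank_files(checkpoint_dir):
    # Collect every per-rank file written by the tests in this section and merge
    # them into a single ORT-format state dictionary.
    files = sorted(glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')))
    return checkpoint.aggregate_checkpoints(files, pytorch_format=False)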
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected single node state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a single node trainer to the state dictionary from a megatron run: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected single node state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a single node trainer to the state dictionary from a megatron run: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected single node state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_data_parallelism_full_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a data parallel trainer to the state dictionary from a megatron run: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run.. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected data parallel state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a data parallel trainer to the state dictionary from a megatron run: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected data parallel state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a data parallel trainer to the state dictionary from a megatron run: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected data parallel state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_data_parallelism_full_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a data parallel trainer to the state dictionary from a megatron run: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected data parallel state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _ , _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict_'+str(world_rank)+'.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict_'+str(world_rank)] + + # compare all states for each rank independently + assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) + +@distributed_setup +def test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed megatron node trainers: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _ , _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict_'+str(world_rank)+'.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict_'+str(world_rank)] + + # compare all states for each rank independently + assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) + +@distributed_setup +def test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed megatron node trainers: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero+megatron node trainers: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero+megatron node trainers : + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero+megatron node trainers: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : world_size + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed megatron and distributed zero+megatron node trainers: + # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +########################################################################################################################################### +# LOAD TO ZERO+MEGATRON +########################################################################################################################################### + +@distributed_setup +def test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a single node trainer to the state dictionary from a zero+megatron run: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. 
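The ZeRO+Megatron target configurations below combine ZeRO stage 1 with horizontal_parallel_size set to world_size/2, so the ranks are split into Megatron model-parallel groups and the remaining factor is used for data parallelism/ZeRO. A minimal sketch of that arithmetic, assuming the data-parallel size is derived by simple division and that ranks are grouped contiguously (the actual rank-to-group mapping is decided by ORT's distributed run context and may differ):

# Illustrative numbers only: assume a 4-GPU run like the tests in this file.
world_size = 4
horizontal_parallel_size = int(world_size / 2)               # 2-way Megatron model parallelism
data_parallel_size = world_size // horizontal_parallel_size  # 2-way data parallelism / ZeRO stage 1

for world_rank in range(world_size):
    # Assumed contiguous grouping: model-parallel groups {0,1} and {2,3},
    # data-parallel (ZeRO) groups {0,2} and {1,3}.
    model_parallel_rank = world_rank % horizontal_parallel_size
    data_parallel_rank = world_rank // horizontal_parallel_size
    print(world_rank, model_parallel_rank, data_parallel_rank)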
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected single node state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a single node trainer to the state dictionary from a zero+megatron run: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected single node state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a single node trainer to the state dictionary from a zero+megatron run: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected single node state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a single node trainer to the state dictionary from a zero+megatron run: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. 
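Every distributed load test in this file repeats the same save / barrier / rank-0 aggregate / cleanup boilerplate. Purely as an illustration of that shared pattern (not part of the change), it could be factored into a helper along these lines, reusing the os, glob, dist, checkpoint, save_ort_ckpt, and assert_all_states_close_ort names already in scope in the test module:

def aggregate_and_compare(state_dict_post_checkpoint, expected_state_dict, world_rank,
                          checkpoint_dir, reshape_states=False):
    # Hypothetical helper: each rank saves its post-load state, rank 0 globs the
    # per-rank files, aggregates them, and compares against the expected states.
    filepath = os.path.join(checkpoint_dir, 'distributed_state_' + str(world_rank) + '.ort.pt')
    save_ort_ckpt(state_dict_post_checkpoint, filepath)
    dist.barrier()
    if world_rank == 0:
        checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt'))
        aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False)
        assert_all_states_close_ort(expected_state_dict, aggregated_state_dict, reshape_states=reshape_states)
    dist.barrier()
    os.remove(filepath)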
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected single node state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a data parallel trainer to the state dictionary from a zero+megatron run: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected data parallel state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a data parallel trainer to the state dictionary from a zero+megatron run: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected data parallel state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a data parallel trainer to the state dictionary from a zero+megatron run: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected data parallel state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict'] + + # To compare state dictionary from a data parallel trainer to the state dictionary from a zero+megatron run: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel run. 
+ filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the manually aggregated state dictionary with the expected data parallel state dictionary + assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
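The comparisons that follow pass reshape_states=True because, as the in-test comment notes, a different model-parallel configuration shards the model and optimizer states differently, so the aggregated tensors may come back with shapes that differ from the expected ones. The real assert_all_states_close_ort lives in the test helpers and is not shown here; a minimal sketch of the reshape-then-compare idea, assuming a flat mapping of parameter names to numpy arrays:

import numpy as np

def states_close(expected, actual, reshape_states=False, rtol=1e-5, atol=1e-7):
    # Hypothetical comparison: every expected tensor must exist in the loaded
    # states and be numerically close; with reshape_states=True the loaded
    # tensor is first reshaped to the expected shape before comparing.
    for name, expected_tensor in expected.items():
        expected_array = np.asarray(expected_tensor)
        loaded_array = np.asarray(actual[name])
        if reshape_states:
            loaded_array = loaded_array.reshape(expected_array.shape)
        assert np.allclose(expected_array, loaded_array, rtol=rtol, atol=atol), name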
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current full precision zero trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed megatron node trainers: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed megatron node trainers: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed megatron node trainers: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed megatron node trainers: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _ , _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict_'+str(world_rank)+'.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict_'+str(world_rank)] + + # compare all states for each rank independently + assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) + +@distributed_setup +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed zero+megatron node trainers with different precisions: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
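The same-configuration reload tests above skip aggregation entirely: each rank compares its own state against a 'state_dict_<world_rank>.pkl' written by the companion save-checkpoint run (the create_orttrainer_and_save_checkpoint_bart calls later in this diff pass state_dict_key_name='state_dict_'+str(world_rank)). A sketch of how such a per-rank pickle might be produced, with a hypothetical helper name:

import os
import pickle

def dump_per_rank_state(state_dict, checkpoint_dir, world_rank):
    # Hypothetical sketch: store one rank's state dictionary under a per-rank
    # key so the matching load test can read it back with the same key.
    key = 'state_dict_' + str(world_rank)
    with open(os.path.join(checkpoint_dir, key + '.pkl'), 'wb') as f:
        pickle.dump({key: state_dict}, f)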
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + +@distributed_setup +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, _ , _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + state = None + with open(os.path.join(checkpoint_dir, 'state_dict_'+str(world_rank)+'.pkl'), 'rb') as f: + state = pickle.load(f) + state_dict_pre_checkpoint = state['state_dict_'+str(world_rank)] + + # compare all states for each rank independently + assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) + +@distributed_setup +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size' : int(world_size/2), + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + + # extract state dictionaries to compare + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + + # To compare state dictionary between distributed zero+megatron and distributed zero+megatron node trainers with different precisions: + # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) + # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. + # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. 
+ # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states + filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filepath = os.path.join(checkpoint_dir, filename) + save_ort_ckpt(state_dict_post_checkpoint, filepath) + dist.barrier() + + if world_rank == 0: + # manually aggregate the states for the current trainer + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) + + # compare the two state dictionaries + assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) + + dist.barrier() + os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + function_map = { # all config to single node config 'test_load_from_single_node_full_precision_into_single_node_full_precision': test_load_from_single_node_full_precision_into_single_node_full_precision, @@ -1036,6 +3324,14 @@ def test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_p 'test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision': test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision, 'test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision': test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision, 'test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision': test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision, + 'test_load_from_distributed_megatron_full_precision_into_single_node_full_precision': test_load_from_distributed_megatron_full_precision_into_single_node_full_precision, + 'test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision': test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision, + 'test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision': test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision, + 'test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision': test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision, + 'test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision': test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision, + 'test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision, + 'test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision, + 'test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision': test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision, # all config to data parallel node config 'test_load_from_single_node_full_precision_into_data_parallelism_full_precision': test_load_from_single_node_full_precision_into_data_parallelism_full_precision, @@ -1050,6 +3346,14 @@ def test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_p 'test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision': 
test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision, 'test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision, 'test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision, + 'test_load_from_distributed_megatron_full_precision_into_data_parallelism_full_precision': test_load_from_distributed_megatron_full_precision_into_data_parallelism_full_precision, + 'test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_full_precision': test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_full_precision, + 'test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_mixed_precision, + 'test_load_from_distributed_megatron_full_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_megatron_full_precision_into_data_parallelism_mixed_precision, + 'test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_full_precision': test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_full_precision, + 'test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_full_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_full_precision, + 'test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_mixed_precision, + 'test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_mixed_precision, # all config to distributed zero node config 'test_load_from_single_node_full_precision_into_distributed_zero_full_precision': test_load_from_single_node_full_precision_into_distributed_zero_full_precision, @@ -1063,7 +3367,59 @@ def test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_p 'test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision': test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision, - 'test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision + 'test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision, + 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision': test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision, + 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision, + 
'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision, + 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision, + 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision, + 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision, + 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision, + 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision, + + # all config to distributed megatron node config + 'test_load_from_single_node_full_precision_into_distributed_megatron_full_precision': test_load_from_single_node_full_precision_into_distributed_megatron_full_precision, + 'test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision': test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision, + 'test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision': test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision, + 'test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision': test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision, + 'test_load_from_data_parallelism_full_precision_into_distributed_megatron_full_precision': test_load_from_data_parallelism_full_precision_into_distributed_megatron_full_precision, + 'test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_full_precision': test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_full_precision, + 'test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_mixed_precision': test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_mixed_precision, + 'test_load_from_data_parallelism_full_precision_into_distributed_megatron_mixed_precision': test_load_from_data_parallelism_full_precision_into_distributed_megatron_mixed_precision, + 'test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision': test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision, + 'test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision, + 'test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision, + 'test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision, + 'test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision': test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision, 
+ 'test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision, + 'test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision, + 'test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision, + 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision, + 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision, + 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision, + 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision, + + # all config to distributed zero + megatron node config + 'test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision': test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision, + 'test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision': test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision, + 'test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision, + 'test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision, + 'test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_full_precision': test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_full_precision, + 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_full_precision': test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_full_precision, + 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_mixed_precision, + 'test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_mixed_precision, + 'test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision, + 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision, + 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision, + 
'test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision, + 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision, + 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision, + 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision, + 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision, + 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision, + 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision, + 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision, + 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision } parser = argparse.ArgumentParser(description='Test saved states of trainers to loaded states') parser.add_argument('--scenario', choices=function_map.keys(), help='training scenario to test saved and loaded states', required=True) diff --git a/orttraining/orttraining/test/python/checkpoint/orttraining_test_save_checkpoint.py b/orttraining/orttraining/test/python/checkpoint/orttraining_test_save_checkpoint.py index 152b9098100d7..b00b77af81e13 100644 --- a/orttraining/orttraining/test/python/checkpoint/orttraining_test_save_checkpoint.py +++ b/orttraining/orttraining/test/python/checkpoint/orttraining_test_save_checkpoint.py @@ -11,14 +11,14 @@ import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from _test_helpers import distributed_setup, create_orttrainer_and_save_checkpoint +from _test_helpers import distributed_setup, create_orttrainer_and_save_checkpoint, create_orttrainer_and_save_checkpoint_bart -def single_node_full_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/single_node/full_precision/'): +def single_node_full_precision(checkpoint_dir, device = 'cuda'): opts = {'device' : {'id' : device}, 'debug' : {'deterministic_compute': True}} create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir) -def single_node_mixed_precision(device = 'cuda', checkpoint_dir = 'checkpoint_dir/single_node/mixed_precision/'): +def single_node_mixed_precision(checkpoint_dir, device = 'cuda'): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -29,8 +29,24 @@ def single_node_mixed_precision(device = 'cuda', checkpoint_dir = 'checkpoint_di } create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir) +def 
single_node_full_precision_bart(checkpoint_dir, device = 'cuda'): + opts = {'device' : {'id' : device}, + 'debug' : {'deterministic_compute': True}} + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir) + +def single_node_mixed_precision_bart(checkpoint_dir, device = 'cuda'): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir) + @distributed_setup -def data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/data_parallelism/full_precision/'): +def data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -44,7 +60,7 @@ def data_parallelism_full_precision(world_rank, world_size, device, checkpoint_d create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir if world_rank == 0 else None) @distributed_setup -def data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/data_parallelism/mixed_precision/'): +def data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -62,7 +78,39 @@ def data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_ create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir if world_rank == 0 else None) @distributed_setup -def distributed_zero_full_precision_adam(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/adam/'): +def data_parallelism_full_precision_bart(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir if world_rank == 0 else None) + +@distributed_setup +def data_parallelism_mixed_precision_bart(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir if world_rank == 0 else None) + +@distributed_setup +def distributed_zero_full_precision_adam(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -80,7 +128,7 @@ def distributed_zero_full_precision_adam(world_rank, world_size, device, checkpo create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) @distributed_setup -def distributed_zero_mixed_precision_adam(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/adam/'): +def distributed_zero_mixed_precision_adam(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -102,7 +150,7 @@ def distributed_zero_mixed_precision_adam(world_rank, world_size, device, checkp create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) @distributed_setup -def 
distributed_zero_full_precision_lamb(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/lamb/'): +def distributed_zero_full_precision_lamb(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'distributed' : @@ -120,7 +168,7 @@ def distributed_zero_full_precision_lamb(world_rank, world_size, device, checkpo create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) @distributed_setup -def distributed_zero_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/lamb/'): +def distributed_zero_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir): opts = { 'device' : {'id' : device}, 'mixed_precision': @@ -141,15 +189,222 @@ def distributed_zero_mixed_precision_lamb(world_rank, world_size, device, checkp } create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) +@distributed_setup +def distributed_zero_full_precision_lamb_bart(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + +@distributed_setup +def distributed_zero_mixed_precision_lamb_bart(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + + +@distributed_setup +def distributed_megatron_full_precision_adam(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size': world_size + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) + +@distributed_setup +def distributed_megatron_mixed_precision_adam(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size': world_size + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) + +@distributed_setup +def distributed_megatron_full_precision_lamb(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size': world_size + }, + 'debug' : 
{'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + +@distributed_setup +def distributed_megatron_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'horizontal_parallel_size': world_size + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + +@distributed_setup +def distributed_zero_megatron_full_precision_adam(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + }, + 'horizontal_parallel_size': int(world_size/2) + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) + +@distributed_setup +def distributed_zero_megatron_mixed_precision_adam(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + }, + 'horizontal_parallel_size': int(world_size/2) + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) + +@distributed_setup +def distributed_zero_megatron_full_precision_lamb(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + }, + 'horizontal_parallel_size': int(world_size/2) + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + +@distributed_setup +def distributed_zero_megatron_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir): + opts = { + 'device' : {'id' : device}, + 'mixed_precision': + { + 'enabled': True + }, + 'distributed' : + { + 'world_rank' : world_rank, + 'world_size' : world_size, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + }, + 'horizontal_parallel_size': int(world_size/2) + }, + 'debug' : {'deterministic_compute': True} + } + create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + function_map = { 'single_node_full_precision': single_node_full_precision, 'single_node_mixed_precision': single_node_mixed_precision, + 'single_node_full_precision_bart': single_node_full_precision_bart, + 'single_node_mixed_precision_bart': single_node_mixed_precision_bart, 'data_parallelism_full_precision': data_parallelism_full_precision, 'data_parallelism_mixed_precision': data_parallelism_mixed_precision, + 
'data_parallelism_full_precision_bart': data_parallelism_full_precision_bart, + 'data_parallelism_mixed_precision_bart': data_parallelism_mixed_precision_bart, 'distributed_zero_full_precision_adam': distributed_zero_full_precision_adam, 'distributed_zero_mixed_precision_adam': distributed_zero_mixed_precision_adam, 'distributed_zero_full_precision_lamb': distributed_zero_full_precision_lamb, - 'distributed_zero_mixed_precision_lamb': distributed_zero_mixed_precision_lamb + 'distributed_zero_mixed_precision_lamb': distributed_zero_mixed_precision_lamb, + 'distributed_zero_full_precision_lamb_bart': distributed_zero_full_precision_lamb_bart, + 'distributed_zero_mixed_precision_lamb_bart': distributed_zero_mixed_precision_lamb_bart, + 'distributed_megatron_full_precision_adam': distributed_megatron_full_precision_adam, + 'distributed_megatron_mixed_precision_adam': distributed_megatron_mixed_precision_adam, + 'distributed_megatron_full_precision_lamb': distributed_megatron_full_precision_lamb, + 'distributed_megatron_mixed_precision_lamb': distributed_megatron_mixed_precision_lamb, + 'distributed_zero_megatron_full_precision_adam': distributed_zero_megatron_full_precision_adam, + 'distributed_zero_megatron_mixed_precision_adam': distributed_zero_megatron_mixed_precision_adam, + 'distributed_zero_megatron_full_precision_lamb': distributed_zero_megatron_full_precision_lamb, + 'distributed_zero_megatron_mixed_precision_lamb': distributed_zero_megatron_mixed_precision_lamb } parser = argparse.ArgumentParser(description='Save states of trainers') parser.add_argument('--scenario', choices=function_map.keys(), help='training scenario to save states', required=True) diff --git a/orttraining/orttraining/test/python/orttraining_test_checkpoint.py b/orttraining/orttraining/test/python/orttraining_test_checkpoint.py index 917ac8ff3bbac..6d348a949a682 100644 --- a/orttraining/orttraining/test/python/orttraining_test_checkpoint.py +++ b/orttraining/orttraining/test/python/orttraining_test_checkpoint.py @@ -18,9 +18,9 @@ # - orttraining_test_save_checkpoint.py: responsible for saving all checkpoint files and trained states # - orttraining_test_load_checkpoint.py: loading the saved checkpoints and the saved states and asserting whether # the saved states match the loaded states. -# - and a total of 36 tests encompassing checkpointing tests: -# - from [onnxruntime orttrainer][full_precision, mixed_precision][single node training, data parallel training, distributed zero training] to -# [onnxruntime orttrainer, pytorch][full_precision, mixed_precision][single node training, data parallel training, distributed zero training] +# - and checkpointing tests covering the following scenarios: +# - from [onnxruntime orttrainer][full_precision, mixed_precision][single node training, data parallel training, distributed zero, distributed megatron, distributed zero+megatron training] to +# [onnxruntime orttrainer, pytorch][full_precision, mixed_precision][single node training, data parallel training, distributed zero, distributed megatron, distributed zero+megatron training] # - all tests cannot be written in the same process because: # - some of them require to be run in a distributed environment (using mpirun) while others can be run using a single process.
# - there is a known limitation where the distributed training run context is implemented as a singleton, so in the same process, no more than one @@ -66,6 +66,22 @@ distributed_zero_full_precision_lamb_path = os.path.join(checkpoint_dir, 'distributed_zero', 'full_precision', 'lamb') distributed_zero_mixed_precision_lamb_path = os.path.join(checkpoint_dir, 'distributed_zero', 'mixed_precision', 'lamb') +# megatron saving and loading uses a different model +single_node_full_precision_bart_path = os.path.join(checkpoint_dir, 'bart', 'single_node', 'full_precision') +single_node_mixed_precision_bart_path = os.path.join(checkpoint_dir, 'bart', 'single_node', 'mixed_precision') +data_parallelism_full_precision_bart_path = os.path.join(checkpoint_dir, 'bart', 'data_parallelism', 'full_precision') +data_parallelism_mixed_precision_bart_path = os.path.join(checkpoint_dir, 'bart', 'data_parallelism', 'mixed_precision') +distributed_zero_full_precision_lamb_bart_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero', 'full_precision', 'lamb') +distributed_zero_mixed_precision_lamb_bart_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero', 'mixed_precision', 'lamb') +distributed_megatron_full_precision_adam_path = os.path.join(checkpoint_dir, 'bart', 'distributed_megatron', 'full_precision', 'adam') +distributed_megatron_mixed_precision_adam_path = os.path.join(checkpoint_dir, 'bart', 'distributed_megatron', 'mixed_precision', 'adam') +distributed_megatron_full_precision_lamb_path = os.path.join(checkpoint_dir, 'bart', 'distributed_megatron', 'full_precision', 'lamb') +distributed_megatron_mixed_precision_lamb_path = os.path.join(checkpoint_dir, 'bart', 'distributed_megatron', 'mixed_precision', 'lamb') +distributed_zero_megatron_full_precision_adam_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero_megatron', 'full_precision', 'adam') +distributed_zero_megatron_mixed_precision_adam_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero_megatron', 'mixed_precision', 'adam') +distributed_zero_megatron_full_precision_lamb_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero_megatron', 'full_precision', 'lamb') +distributed_zero_megatron_mixed_precision_lamb_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero_megatron', 'mixed_precision', 'lamb') + # save all checkpoint files (pre-checkpoint) _single_run(save_checkpoint_file, 'single_node_full_precision', single_node_full_precision_path) _single_run(save_checkpoint_file, 'single_node_mixed_precision', single_node_mixed_precision_path) @@ -76,6 +92,22 @@ _distributed_run(save_checkpoint_file, 'distributed_zero_full_precision_lamb', distributed_zero_full_precision_lamb_path) _distributed_run(save_checkpoint_file, 'distributed_zero_mixed_precision_lamb', distributed_zero_mixed_precision_lamb_path) +_single_run(save_checkpoint_file, 'single_node_full_precision_bart', single_node_full_precision_bart_path) +_single_run(save_checkpoint_file, 'single_node_mixed_precision_bart', single_node_mixed_precision_bart_path) +_distributed_run(save_checkpoint_file, 'data_parallelism_full_precision_bart', data_parallelism_full_precision_bart_path) +_distributed_run(save_checkpoint_file, 'data_parallelism_mixed_precision_bart', data_parallelism_mixed_precision_bart_path) +_distributed_run(save_checkpoint_file, 'distributed_zero_full_precision_lamb_bart', distributed_zero_full_precision_lamb_bart_path) +_distributed_run(save_checkpoint_file, 'distributed_zero_mixed_precision_lamb_bart', 
distributed_zero_mixed_precision_lamb_bart_path) + +_distributed_run(save_checkpoint_file, 'distributed_megatron_full_precision_adam', distributed_megatron_full_precision_adam_path) +_distributed_run(save_checkpoint_file, 'distributed_megatron_mixed_precision_adam', distributed_megatron_mixed_precision_adam_path) +_distributed_run(save_checkpoint_file, 'distributed_megatron_full_precision_lamb', distributed_megatron_full_precision_lamb_path) +_distributed_run(save_checkpoint_file, 'distributed_megatron_mixed_precision_lamb', distributed_megatron_mixed_precision_lamb_path) +_distributed_run(save_checkpoint_file, 'distributed_zero_megatron_full_precision_adam', distributed_zero_megatron_full_precision_adam_path) +_distributed_run(save_checkpoint_file, 'distributed_zero_megatron_mixed_precision_adam', distributed_zero_megatron_mixed_precision_adam_path) +_distributed_run(save_checkpoint_file, 'distributed_zero_megatron_full_precision_lamb', distributed_zero_megatron_full_precision_lamb_path) +_distributed_run(save_checkpoint_file, 'distributed_zero_megatron_mixed_precision_lamb', distributed_zero_megatron_mixed_precision_lamb_path) + # load checkpoint files (post-checkpoint) # going to single node trainer _single_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_single_node_full_precision', single_node_full_precision_path) @@ -90,6 +122,14 @@ _single_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision', distributed_zero_mixed_precision_lamb_path) _single_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision', distributed_zero_mixed_precision_lamb_path) _single_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision', distributed_zero_full_precision_lamb_path) +_single_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_single_node_full_precision', distributed_megatron_full_precision_lamb_path) +_single_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision', distributed_megatron_mixed_precision_lamb_path) +_single_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision', distributed_megatron_mixed_precision_lamb_path) +_single_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision', distributed_megatron_full_precision_lamb_path) +_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision', distributed_zero_megatron_full_precision_lamb_path) +_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision', distributed_zero_megatron_mixed_precision_lamb_path) +_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) +_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) # going to data parallel trainer _distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_data_parallelism_full_precision', single_node_full_precision_path) @@ -104,6 +144,14 @@ _distributed_run(load_checkpoint_file, 
'test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision', distributed_zero_mixed_precision_lamb_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision', distributed_zero_mixed_precision_lamb_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision', distributed_zero_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_data_parallelism_full_precision', distributed_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_full_precision', distributed_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_mixed_precision', distributed_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_data_parallelism_mixed_precision', distributed_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_full_precision', distributed_zero_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_full_precision', distributed_zero_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) # going to distributed zero trainer _distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_zero_full_precision', single_node_full_precision_path) @@ -118,12 +166,72 @@ _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision', distributed_zero_mixed_precision_lamb_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision', distributed_zero_mixed_precision_lamb_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision', distributed_zero_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision', distributed_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision', distributed_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision', distributed_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision', distributed_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision', 
distributed_zero_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision', distributed_zero_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) + +# going to distributed megatron trainer +_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_megatron_full_precision', single_node_full_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision', single_node_mixed_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision', single_node_mixed_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision', single_node_full_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_distributed_megatron_full_precision', data_parallelism_full_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_full_precision', data_parallelism_mixed_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_mixed_precision', data_parallelism_mixed_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_distributed_megatron_mixed_precision', data_parallelism_full_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision', distributed_zero_full_precision_lamb_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision', distributed_zero_mixed_precision_lamb_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision', distributed_zero_mixed_precision_lamb_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision', distributed_zero_full_precision_lamb_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision', distributed_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision', distributed_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision', distributed_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision', distributed_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 
'test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision', distributed_zero_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision', distributed_zero_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) + +# going to distributed zero+megatron trainer +_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision', single_node_full_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision', single_node_mixed_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision', single_node_mixed_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision', single_node_full_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_full_precision', data_parallelism_full_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_full_precision', data_parallelism_mixed_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_mixed_precision', data_parallelism_mixed_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_mixed_precision', data_parallelism_full_precision_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision', distributed_zero_full_precision_lamb_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision', distributed_zero_mixed_precision_lamb_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision', distributed_zero_mixed_precision_lamb_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision', distributed_zero_full_precision_lamb_bart_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision', distributed_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision', distributed_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision', distributed_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 
'test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision', distributed_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision', distributed_zero_megatron_full_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision', distributed_zero_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) +_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) # checkpoint aggregation tests _single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_full_precision_adam', distributed_zero_full_precision_adam_path) _single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_mixed_precision_adam', distributed_zero_mixed_precision_adam_path) _single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_mixed_precision_lamb', distributed_zero_mixed_precision_lamb_path) _single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_full_precision_lamb', distributed_zero_full_precision_lamb_path) +_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_megatron_full_precision_adam', distributed_megatron_full_precision_adam_path) +_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_megatron_mixed_precision_adam', distributed_megatron_mixed_precision_adam_path) +_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_megatron_mixed_precision_lamb', distributed_megatron_mixed_precision_lamb_path) +_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_megatron_full_precision_lamb', distributed_megatron_full_precision_lamb_path) +_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_megatron_full_precision_adam', distributed_zero_megatron_full_precision_adam_path) +_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_megatron_mixed_precision_adam', distributed_zero_megatron_mixed_precision_adam_path) +_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_megatron_mixed_precision_lamb', distributed_zero_megatron_mixed_precision_lamb_path) +_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_megatron_full_precision_lamb', distributed_zero_megatron_full_precision_lamb_path) # optimizer state loading into model-parallel tests _distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_full_precision_adam', distributed_zero_full_precision_adam_path) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py index 1434deb588ee6..ae67b7250039e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py @@ -21,6 +21,8 @@ def _create_trainer(zero_enabled=False): opts['distributed'] = { 'world_rank' : 0, 'world_size' : 1, 
+ 'horizontal_parallel_size' : 1, + 'data_parallel_size' : 1, 'allreduce_post_accumulation' : True, 'deepspeed_zero_optimization': { @@ -489,6 +491,8 @@ def test_load_checkpoint(aggregate_checkpoints_mock, load_mock): 'mixed_precision': np.bool_(False), 'world_rank': np.int64(0), 'world_size': np.int64(1), + 'horizontal_parallel_size' : np.int64(1), + 'data_parallel_size' : np.int64(1), 'zero_stage': np.int64(0) } state_dict = { @@ -498,6 +502,8 @@ def test_load_checkpoint(aggregate_checkpoints_mock, load_mock): 'mixed_precision': np.bool_(False), 'world_rank': np.int64(0), 'world_size': np.int64(1), + 'horizontal_parallel_size' : np.int64(1), + 'data_parallel_size' : np.int64(1), 'zero_stage': np.int64(0) } } @@ -522,18 +528,24 @@ def test_load_checkpoint(aggregate_checkpoints_mock, load_mock): 'mixed_precision': np.bool_(False), 'world_rank': np.int64(0), 'world_size': np.int64(4), + 'horizontal_parallel_size' : np.int64(1), + 'data_parallel_size' : np.int64(4), 'zero_stage': np.int64(1) }, { 'mixed_precision': np.bool_(True), 'world_rank': np.int64(0), 'world_size': np.int64(1), + 'horizontal_parallel_size' : np.int64(1), + 'data_parallel_size' : np.int64(1), 'zero_stage': np.int64(1) }, { 'mixed_precision': np.bool_(True), 'world_rank': np.int64(0), 'world_size': np.int64(1), + 'horizontal_parallel_size' : np.int64(1), + 'data_parallel_size' : np.int64(1), 'zero_stage': np.int64(1) } ]) @@ -560,6 +572,8 @@ def test_load_checkpoint_user_dict(aggregate_checkpoints_mock, load_mock): 'mixed_precision': np.bool_(False), 'world_rank': np.int64(0), 'world_size': np.int64(1), + 'horizontal_parallel_size': np.int64(1), + 'data_parallel_size': np.int64(1), 'zero_stage': np.int64(0) } state_dict = { @@ -569,6 +583,8 @@ def test_load_checkpoint_user_dict(aggregate_checkpoints_mock, load_mock): 'mixed_precision': np.bool_(False), 'world_rank': np.int64(0), 'world_size': np.int64(1), + 'horizontal_parallel_size': np.int64(1), + 'data_parallel_size': np.int64(1), 'zero_stage': np.int64(0) }, 'user_dict': _checkpoint_storage.to_serialized_hex({'array': torch.tensor(np.arange(5))}) @@ -586,6 +602,8 @@ def test_checkpoint_aggregation(load_mock): 'mixed_precision': np.bool_(False), 'world_rank': np.int64(0), 'world_size': np.int64(2), + 'horizontal_parallel_size' : np.int64(1), + 'data_parallel_size' : np.int64(2), 'zero_stage': np.int64(1), 'optimizer_name': b'Adam' } @@ -593,6 +611,8 @@ def test_checkpoint_aggregation(load_mock): 'mixed_precision': np.bool_(False), 'world_rank': np.int64(1), 'world_size': np.int64(2), + 'horizontal_parallel_size' : np.int64(1), + 'data_parallel_size' : np.int64(2), 'zero_stage': np.int64(1), 'optimizer_name': b'Adam' } @@ -620,6 +640,8 @@ def test_checkpoint_aggregation(load_mock): 'mixed_precision': np.bool_(False), 'world_rank': np.int64(0), 'world_size': np.int64(1), + 'horizontal_parallel_size' : np.int64(1), + 'data_parallel_size' : np.int64(1), 'zero_stage': np.int64(0), 'optimizer_name': b'Adam' }, @@ -651,6 +673,8 @@ def test_checkpoint_aggregation(load_mock): 'mixed_precision': np.bool_(False), 'world_rank': np.int64(1), 'world_size': np.int64(1), + 'horizontal_parallel_size' : np.int64(1), + 'data_parallel_size' : np.int64(1), 'zero_stage': np.int64(0), 'optimizer_name': b'Adam' }, @@ -659,7 +683,7 @@ def test_checkpoint_aggregation(load_mock): } } - load_mock.side_effect = [trainer_options1, trainer_options2, state_dict1, state_dict2] + load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2] state_dict = 
checkpoint.aggregate_checkpoints(['abc', 'def'], pytorch_format=False) assert (state_dict['model']['full_precision']['optimizer_sharded'] == np.array([1, 2, 3])).all() @@ -674,6 +698,8 @@ def test_checkpoint_aggregation(load_mock): assert state_dict['trainer_options']['mixed_precision'] == False assert state_dict['trainer_options']['world_rank'] == 0 assert state_dict['trainer_options']['world_size'] == 1 + assert state_dict['trainer_options']['horizontal_parallel_size'] == 1 + assert state_dict['trainer_options']['data_parallel_size'] == 1 assert state_dict['trainer_options']['zero_stage'] == 0 assert state_dict['trainer_options']['optimizer_name'] == b'Adam' @@ -683,6 +709,8 @@ def test_checkpoint_aggregation_mixed_precision(load_mock): 'mixed_precision': np.bool_(True), 'world_rank': np.int64(0), 'world_size': np.int64(2), + 'horizontal_parallel_size': np.int64(1), + 'data_parallel_size': np.int64(2), 'zero_stage': np.int64(1), 'optimizer_name': b'Adam' } @@ -690,6 +718,8 @@ def test_checkpoint_aggregation_mixed_precision(load_mock): 'mixed_precision': np.bool_(True), 'world_rank': np.int64(1), 'world_size': np.int64(2), + 'horizontal_parallel_size': np.int64(1), + 'data_parallel_size': np.int64(2), 'zero_stage': np.int64(1), 'optimizer_name': b'Adam' } @@ -717,6 +747,8 @@ def test_checkpoint_aggregation_mixed_precision(load_mock): 'mixed_precision': np.bool_(True), 'world_rank': np.int64(0), 'world_size': np.int64(1), + 'horizontal_parallel_size': np.int64(1), + 'data_parallel_size': np.int64(1), 'zero_stage': np.int64(0), 'optimizer_name': b'Adam' }, @@ -748,6 +780,8 @@ def test_checkpoint_aggregation_mixed_precision(load_mock): 'mixed_precision': np.bool_(True), 'world_rank': np.int64(1), 'world_size': np.int64(1), + 'horizontal_parallel_size': np.int64(1), + 'data_parallel_size': np.int64(1), 'zero_stage': np.int64(0), 'optimizer_name': b'Adam' }, @@ -756,7 +790,7 @@ def test_checkpoint_aggregation_mixed_precision(load_mock): } } - load_mock.side_effect = [trainer_options1, trainer_options2, state_dict1, state_dict2] + load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2] state_dict = checkpoint.aggregate_checkpoints(['abc', 'def'], pytorch_format=False) assert (state_dict['model']['full_precision']['sharded'] == np.array([[1, 2, 3], [4, 5, 6]])).all() @@ -771,5 +805,7 @@ def test_checkpoint_aggregation_mixed_precision(load_mock): assert state_dict['trainer_options']['mixed_precision'] == True assert state_dict['trainer_options']['world_rank'] == 0 assert state_dict['trainer_options']['world_size'] == 1 + assert state_dict['trainer_options']['horizontal_parallel_size'] == 1 + assert state_dict['trainer_options']['data_parallel_size'] == 1 assert state_dict['trainer_options']['zero_stage'] == 0 assert state_dict['trainer_options']['optimizer_name'] == b'Adam' diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index 6af5a240d2948..c0c8c8a89daec 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -60,6 +60,8 @@ def testORTTrainerOptionsDefaultValues(test_input): 'sliced_tensor_names': [] }, 'allreduce_post_accumulation': False, + 'data_parallel_size': 1, + 'horizontal_parallel_size':1, 'deepspeed_zero_optimization': { 'stage' : 0, }, diff --git 
a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-distributed-test-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-distributed-test-ci-pipeline.yml index 743306bc9a0d1..d79069d6ef603 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-distributed-test-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-distributed-test-ci-pipeline.yml @@ -40,7 +40,7 @@ jobs: --cwd /build/RelWithDebInfo \ displayName: 'Run orttraining_distributed_tests.py' condition: succeededOrFailed() - timeoutInMinutes: 30 + timeoutInMinutes: 60 - template: templates/component-governance-component-detection-steps.yml parameters:
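
A note on the new parallelism fields exercised by the tests above: the pure Megatron save scenarios set 'horizontal_parallel_size' equal to world_size, the ZeRO+Megatron scenarios set it to world_size/2, and the checkpoint metadata now records both 'horizontal_parallel_size' and 'data_parallel_size'. The following is only an illustrative sketch of the implied rank topology, assuming the usual convention that world_size = data_parallel_size * horizontal_parallel_size and using four ranks as an example; the helper name is hypothetical, not part of the ORT API.

# Illustrative only: derive the data-parallel dimension implied by the test
# configurations above (assumes world_size = data_parallel_size * horizontal_parallel_size).
def derive_data_parallel_size(world_size, horizontal_parallel_size):
    if world_size % horizontal_parallel_size != 0:
        raise ValueError("world_size must be divisible by horizontal_parallel_size")
    return world_size // horizontal_parallel_size

# Pure Megatron scenario: horizontal_parallel_size == world_size, so a single data-parallel replica.
assert derive_data_parallel_size(world_size=4, horizontal_parallel_size=4) == 1
# ZeRO + Megatron scenario: horizontal_parallel_size == world_size / 2 leaves two replicas,
# matching the new 'data_parallel_size' metadata field.
assert derive_data_parallel_size(world_size=4, horizontal_parallel_size=2) == 2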
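
The aggregation unit tests above now stub load_mock.side_effect with an extra trainer-options entry ([trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2]), presumably because checkpoint.aggregate_checkpoints reads the trainer options one additional time before loading the per-rank state dicts. The standalone snippet below is only a reminder of how unittest.mock consumes a side_effect sequence; the string values and call arguments are placeholders, not ORT data.

# Standalone illustration of unittest.mock side_effect ordering (placeholder values).
from unittest.mock import Mock

load = Mock(side_effect=['options_rank0', 'options_rank1', 'options_again', 'state_rank0', 'state_rank1'])

# Successive calls consume successive entries, regardless of the arguments passed,
# so the stubbed values must be listed in exactly the order the code under test loads them.
assert load('ckpt0') == 'options_rank0'
assert load('ckpt1') == 'options_rank1'
assert load('ckpt0') == 'options_again'
assert load('ckpt0', key='model') == 'state_rank0'
assert load('ckpt1', key='model') == 'state_rank1'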
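
As the header comment in orttraining_test_checkpoint.py explains, each scenario must run in its own process: some need mpirun, and the distributed run context is a singleton that cannot be reconfigured within one interpreter. The _single_run and _distributed_run helpers that enforce this are defined outside the hunks shown here, so the sketch below is a hedged illustration of such a runner; the mpirun invocation, the rank count, and the --checkpoint_dir flag are assumptions, not the repository's actual helper.

# Hypothetical sketch of a per-scenario runner; the real _single_run/_distributed_run
# helpers are not shown in this diff, and the CLI flags below are assumed.
import subprocess
import sys

def run_distributed_scenario(script, scenario, checkpoint_dir, num_ranks=4):
    # A fresh process per scenario avoids re-initializing the singleton
    # distributed run context inside one Python interpreter.
    subprocess.check_call([
        'mpirun', '-n', str(num_ranks), sys.executable, script,
        '--scenario', scenario, '--checkpoint_dir', checkpoint_dir,
    ])

run_distributed_scenario('orttraining_test_save_checkpoint.py',
                         'distributed_zero_megatron_full_precision_lamb',
                         'checkpoint_dir/bart/distributed_zero_megatron/full_precision/lamb')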