Megatron checkpointing (microsoft#6293)
* Add bart fairseq run script

* Add frontend change to enable megatron

* Initial changes for checkpointing

* Megatron optim state loading, checkpoint aggregation, frontend distributed tests for H, D+H

* Add load_checkpoint changes

* Fix CI

* Cleanup

* Fix CI

* review comments

* review comments

* review comments:
ashbhandare authored Jan 22, 2021
1 parent 4442d94 commit 60c772e
Showing 21 changed files with 3,576 additions and 228 deletions.
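In broad terms, this change makes checkpointing work when Megatron-style horizontal parallelism is enabled: each rank owns only a partition of some weights and of their optimizer state, the partition bookkeeping (partition name, whether the weight itself is partitioned) is recorded during graph transformation, and the per-rank saved states can later be aggregated back into full tensors. A minimal sketch of the aggregation idea for a column-partitioned weight follows; the function name, state-dict layout, and key names are illustrative and not the actual ORT checkpoint format.

```python
import numpy as np

def aggregate_column_partitions(rank_state_dicts, weight_name, axis=-1):
    """Rebuild a full tensor from the per-rank slices of a column-partitioned weight.

    rank_state_dicts: list of {name: np.ndarray}, ordered by horizontal-parallel rank.
    Illustrative only; the real ORT checkpoint layout and key names differ.
    """
    parts = [sd[weight_name] for sd in rank_state_dicts]
    return np.concatenate(parts, axis=axis)

# Example: a 4x8 weight column-partitioned across 2 ranks.
full = np.arange(32, dtype=np.float32).reshape(4, 8)
rank0 = {"encoder.dense.weight": full[:, :4]}
rank1 = {"encoder.dense.weight": full[:, 4:]}
restored = aggregate_column_partitions([rank0, rank1], "encoder.dense.weight")
assert np.array_equal(restored, full)
```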
Binary file added onnxruntime/test/testdata/bart_tiny.onnx
Binary file not shown.
@@ -28,7 +28,6 @@ Status AdamOptimizerBuilder::Build(
const std::string& gradient_name = gradient_argdefs[i].name;
const TypeProto* const weight_type_proto = weight_argdefs[i].type_proto;
const TypeProto* const gradient_type_proto = gradient_argdefs[i].type_proto;
weight_to_opt_mapping[weight_name] = {};

// Return either the input gradient/weight/mixed-precision-weight or updated gradient/weight/mixed-precision-weight.
ArgDef output_gradient_argdef = gradient_argdefs[i];
@@ -38,6 +37,7 @@ Status AdamOptimizerBuilder::Build(

// In distributed training, some weights may not be updated by all ranks.
if (opt_configs[i].enabled) {
weight_to_opt_mapping[weight_name] = {};
// The type proto initializer for Update Count
const std::string update_count_string = ADAM_UC_PREFIX + "_" + weight_name; // per weight optimizer requires a per weight update count
TensorProto uc_tensor_proto;
@@ -82,7 +82,6 @@ Status AdamOptimizerBuilder::Build(
const auto element_type = opt_configs[i].use_mixed_precision_moments ?
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16 :
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;

// Add first- and second-order momentums to input list.
for (const auto& moments_prefix : MOMENTS_PREFIXES) {
const std::string gradient_moment_name = moments_prefix + "_" + weight_name;
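The hunks above move the `weight_to_opt_mapping[weight_name]` entry inside the `opt_configs[i].enabled` check, so a per-weight optimizer-state mapping is only recorded for weights whose optimizer actually runs on this rank. A rough Python analogue of that guard (the helper and container names are made up for illustration, not the ORT implementation):

```python
def build_weight_to_opt_mapping(weight_names, enabled):
    """Only weights whose optimizer is enabled on this rank get an entry,
    so checkpoint code will not look for optimizer state this rank never creates."""
    weight_to_opt_mapping = {}
    for name, is_enabled in zip(weight_names, enabled):
        if is_enabled:
            weight_to_opt_mapping[name] = {}  # filled with per-weight state names later
    return weight_to_opt_mapping
```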
@@ -250,6 +250,7 @@ static Status AddParameterPartition(
//Partition the FP32 weight
weight_views = AddPartitionsForParameter(graph, graph_defs, weight_argdef.name, view_shapes, updated_weight_names_map);
ORT_ENFORCE(weight_views.size() == enabled.size());
weight_partition_info[weight_argdef.name].weight_partitioned = true;

// Add View for mixed precision weight.
ArgDef mixed_precision_weight_argdef(opt_config.mixed_precision_weight_arg->Name(), opt_config.mixed_precision_weight_arg->TypeAsProto());
@@ -275,7 +276,7 @@

// Partition initial optimizer state
if (enabled[i]) {
weight_partition_info[weight_argdef.name].view_name = weight_views[i].name;
weight_partition_info[weight_argdef.name].partition_name = weight_views[i].name;

if (!initial_states.empty()) {
ORT_ENFORCE(view_shapes.size() == 3, "Invalid view_shapes vector passed for partitioning.");
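The partition builder now records, for every weight it actually partitions, both the partition (view) name and a `weight_partitioned` flag in `weight_partition_info`; checkpoint save and load use these to translate between original and partitioned names. A rough Python mirror of that record, using the field names from the `PartitionInfo` struct changed later in this diff (the helper itself is illustrative):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class PartitionInfo:
    """Python mirror of the fields TrainingSession::PartitionInfo gains in this change."""
    original_dim: List[int] = field(default_factory=list)  # original shape of the weight
    megatron_row_partition: int = -1  # -1: not partitioned; 0: column; 1: row
    partition_name: str = ""          # name used to look up this rank's partition
    weight_partitioned: bool = False  # False when only optimizer state is partitioned (ZeRO-1)

def record_partition(weight_partition_info, weight_name, view_name):
    # Sketch of what AddParameterPartition now records for each partitioned weight.
    info = weight_partition_info.setdefault(weight_name, PartitionInfo())
    info.weight_partitioned = True
    info.partition_name = view_name
    return info
```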
228 changes: 191 additions & 37 deletions orttraining/orttraining/core/optimizer/megatron_transformer.cc

Large diffs are not rendered by default.

16 changes: 15 additions & 1 deletion orttraining/orttraining/core/optimizer/megatron_transformer.h
@@ -15,13 +15,17 @@ class MegatronTransformer : public GraphTransformer {
std::unordered_map<std::string, std::string>& updated_weight_names,
std::unordered_set<std::string>& weights_to_train,
std::unordered_map<std::string, training::TrainingSession::PartitionInfo>& weight_partition_info,
training::TrainingSession::OptimizerState& initial_optimizer_states,
const IExecutionProvider& cpu_execution_provider,  // Required to get allocator for optimizer partitioning by column
const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("MegatronTransformer", compatible_execution_providers),
horizontal_parallel_rank_(horizontal_parallel_rank),
horizontal_parallel_size_(horizontal_parallel_size),
updated_weight_names_(updated_weight_names),
weights_to_train_(weights_to_train),
weight_partition_info_(weight_partition_info) {}
weight_partition_info_(weight_partition_info),
initial_optimizer_states_(initial_optimizer_states),
cpu_execution_provider_(cpu_execution_provider) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const override;
@@ -63,6 +67,14 @@ class MegatronTransformer : public GraphTransformer {
std::unordered_set<Node*>& dropout_nodes_to_transform,
int32_t& counter) const;

template <class T>
void PartitionBufferByColumn(const T* input,
const int64_t row_count,
const int64_t column_count,
const int64_t column_stride,
const int stride,
std::vector<T>& result) const;

bool PartitionWeightByColumn(const Graph& graph, const NodeArg& input_arg,
ONNX_NAMESPACE::TensorProto& initializer_partition,
int stride = 1) const;
@@ -75,6 +87,8 @@
std::unordered_map<std::string, std::string>& updated_weight_names_;
std::unordered_set<std::string>& weights_to_train_;
std::unordered_map<std::string, training::TrainingSession::PartitionInfo>& weight_partition_info_;
training::TrainingSession::OptimizerState& initial_optimizer_states_;
const IExecutionProvider& cpu_execution_provider_;
};

} // namespace onnxruntime
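The new `PartitionBufferByColumn` helper declared above splits a row-major buffer along its columns across the horizontal-parallel ranks; the stride argument covers weights whose columns consist of several equally sized blocks that must each be partitioned separately (for example fused projection weights). A rough numpy sketch of that indexing, under the assumption that the stride and the world size divide the column count evenly; the real template works on flat buffers obtained through the CPU allocator:

```python
import numpy as np

def partition_by_column(weight, rank, world_size, stride=1):
    """Take this rank's column slice of `weight`.

    With stride > 1 the columns form `stride` equally sized blocks; each block
    is partitioned separately and this rank's slices are concatenated.
    Illustrative sketch only, not the C++ helper itself.
    """
    blocks = np.split(weight, stride, axis=1)        # stride column blocks
    slices = [np.split(b, world_size, axis=1)[rank]  # this rank's share of each block
              for b in blocks]
    return np.concatenate(slices, axis=1)

w = np.arange(24, dtype=np.float32).reshape(2, 12)
part = partition_by_column(w, rank=0, world_size=2, stride=3)  # shape (2, 6)
```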
72 changes: 45 additions & 27 deletions orttraining/orttraining/core/session/training_session.cc
@@ -53,6 +53,7 @@ Status SetupOptimizerParams(
const std::unordered_map<std::string, NodeArg*>& fp32_weight_names_to_mixed_precision_node_args,
const optional<std::string>& loss_scale_input_name,
const TrainingSession::TrainingConfiguration& config,
const TrainingSession::OptimizerState& init_optimizer_states,
OptimizerGraphConfig& opt_graph_config_result,
std::unordered_map<std::string, OptimizerNodeConfig>& opt_node_configs_result,
std::unordered_map<std::string, std::string>& weight_name_map_after_graph_transform) {
@@ -98,12 +99,10 @@ Status SetupOptimizerParams(
opt_node_config.mixed_precision_weight_arg = mixed_precision_weight_name_it->second;
}

// check if initial optimizer states have been provided for weight
if (config.init_optimizer_states) {
const auto optim_state_it = config.init_optimizer_states->find(weight_name);
if (optim_state_it != config.init_optimizer_states->end()) {
opt_node_config.initial_states = optim_state_it->second;
}
// Retrieve the initial optimizer state for this weight, if provided.
const auto optim_state_it = init_optimizer_states.find(original_weight_name);
if (optim_state_it != init_optimizer_states.end()) {
opt_node_config.initial_states = optim_state_it->second;
}

opt_node_configs.emplace(weight_name, std::move(opt_node_config));
@@ -130,12 +129,11 @@ Status SetupOptimizerParams(
opt_graph_config.deepspeed_zero = optimizer_config.deepspeed_zero;

// check if shared initial optimizer states have been provided
if (config.init_optimizer_states) {
const auto optim_state_it = config.init_optimizer_states->find(onnxruntime::training::SHARED_OPTIMIZER_STATES_KEY);
if (optim_state_it != config.init_optimizer_states->end()) {
opt_graph_config.shared_optimizer_states = std::move(optim_state_it->second);
}
const auto optim_state_it = init_optimizer_states.find(onnxruntime::training::SHARED_OPTIMIZER_STATES_KEY);
if (optim_state_it != init_optimizer_states.end()) {
opt_graph_config.shared_optimizer_states = std::move(optim_state_it->second);
}

opt_node_configs_result = std::move(opt_node_configs);
opt_graph_config_result = std::move(opt_graph_config);

@@ -420,6 +418,10 @@ Status TrainingSession::ConfigureForTraining(
}
}

if (config.init_optimizer_states) {
init_optimizer_states_ = config.init_optimizer_states.value();
}

ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config));

ORT_RETURN_IF_ERROR(ApplyModelParallelTransformationsToMainGraph(trainable_initializers, config_result));
@@ -501,7 +503,7 @@ Status TrainingSession::ConfigureForTraining(
std::unordered_map<std::string, OptimizerNodeConfig> opt_node_configs{};
ORT_RETURN_IF_ERROR(SetupOptimizerParams(
weights_to_train_, fp32_weight_name_to_mixed_precision_node_arg,
loss_scale_input_name, config, opt_graph_config, opt_node_configs, config_result.weight_name_map_after_graph_transform));
loss_scale_input_name, config, init_optimizer_states_, opt_graph_config, opt_node_configs, config_result.weight_name_map_after_graph_transform));
TrainingConfigurationResult::OptimizerConfigurationResult optimizer_config_result{};
ORT_RETURN_IF_ERROR(BuildOptimizer(
opt_graph_config, opt_node_configs,
@@ -796,12 +798,16 @@ Status TrainingSession::ApplyModelParallelTransformationsToMainGraph(std::unorde

GraphTransformerManager graph_transformation_mgr{1};
std::vector<std::unique_ptr<GraphTransformer>> transformers_to_register;
// Create the CPU EP here to obtain the CPU allocator used for
// partitioning the optimizer state by column.
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
std::unordered_set<std::string> compatible_eps = {};
LOGS_DEFAULT(WARNING) << horizontal_parallel_size << "-way horizontal model parallel is enabled";
transformers_to_register.emplace_back(onnxruntime::make_unique<MegatronTransformer>(
training::DistributedRunContext::RankInGroup(training::WorkerGroupType::HorizontalParallel),
horizontal_parallel_size, config_result_out.weight_name_map_after_graph_transform, weights_to_train,
config_result_out.weight_partition_info, compatible_eps));
config_result_out.weight_partition_info, init_optimizer_states_, *cpu_execution_provider, compatible_eps));

// Generate and register transformers for level
for (auto& entry : transformers_to_register) {
@@ -1082,10 +1088,13 @@ common::Status TrainingSession::GetOptimizerState(std::unordered_map<std::string
}
// Change key from sharded_name to weight_name using partition_info
for (const auto& weight : weight_partition_info_) {
const auto& it = opt_state_tensors.find(weight.second.view_name);
ORT_ENFORCE(it != opt_state_tensors.end(), "Cannot find weight: " + weight.second.view_name + " in weight_partition_info_");
opt_state_tensors[weight.first] = it->second;
opt_state_tensors.erase(it);
const auto& it = opt_state_tensors.find(weight.second.partition_name);
if (it == opt_state_tensors.end()) {
ORT_RETURN_IF_NOT(allow_missing, "Failed to get optimizer params for partition: " + weight.second.partition_name);
} else {
opt_state_tensors[weight.first] = it->second;
opt_state_tensors.erase(it);
}
}
return Status::OK();
}
@@ -1095,20 +1104,29 @@ common::Status TrainingSession::GetModelState(std::unordered_map<std::string, Na
std::unordered_set<std::string> fp_tensor_names{};
fp_tensor_names.insert(
weights_to_train_.begin(), weights_to_train_.end());
// Add zero sharded weights, only needed for fp32 weights in mixed precision run
for (const auto& weight_sharded_pair : updated_weight_names_map_) {
fp_tensor_names.erase(weight_sharded_pair.first); // remove the original name
fp_tensor_names.insert(weight_sharded_pair.second);
// Add sharded weights
for (const auto& weight : weight_partition_info_) {
if (weight.second.weight_partitioned) {
fp_tensor_names.erase(weight.first); // remove the original name
fp_tensor_names.insert(weight.second.partition_name);
}
}

NameMLValMap fp_weights;
GetSessionState().GetInitializedTensors(fp_tensor_names, allow_missing, fp_weights);
// Change key from sharded_name to weight_name
for (const auto& weight_sharded_pair : updated_weight_names_map_) {
const auto& it = fp_weights.find(weight_sharded_pair.second);
ORT_ENFORCE(it != fp_weights.end(), "Cannot find weight: " + weight_sharded_pair.second + " in updated_weight_names_map_");
fp_weights[weight_sharded_pair.first] = it->second;
fp_weights.erase(it);
// Change key from sharded_name to weight_name using partition_info
for (const auto& weight : weight_partition_info_) {
if (weight.second.weight_partitioned) {
const auto& it = fp_weights.find(weight.second.partition_name);
if (it == fp_weights.end()) {
ORT_RETURN_IF_NOT(allow_missing, "Failed to get weight partition: " + weight.second.partition_name);
} else {
fp_weights[weight.first] = it->second;
fp_weights.erase(it);
}
}
}

model_state_tensors["full_precision"] = fp_weights;
if (include_mixed_precision_weights) {
std::unordered_set<std::string> mp_tensor_names{};
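`GetOptimizerState` and `GetModelState` now look tensors up under the recorded partition name and re-key them to the original weight name, and they tolerate a missing partition when `allow_missing` is set, since a given rank need not hold every partition. A small Python sketch of that re-keying over a plain dict of tensors (names are illustrative):

```python
def rekey_partitions(tensors, weight_to_partition_name, allow_missing=False):
    """Replace partition names with original weight names in `tensors`.

    Sketch of the re-keying done in GetModelState/GetOptimizerState;
    `weight_to_partition_name` maps weight name -> partition name.
    """
    for weight_name, partition_name in weight_to_partition_name.items():
        if partition_name not in tensors:
            if not allow_missing:
                raise KeyError(f"Failed to get partition: {partition_name}")
            continue
        tensors[weight_name] = tensors.pop(partition_name)
    return tensors
```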
9 changes: 8 additions & 1 deletion orttraining/orttraining/core/session/training_session.h
@@ -32,9 +32,15 @@ class TrainingSession : public InferenceSession {
* Partition information of each partitioned weight
*/
struct PartitionInfo {
// original shape of the weight
std::vector<int64_t> original_dim;
// indicates whether weight was megatron partitioned or not.
// -1: not partitioned; 0: column partitioned; 1: row partitioned
int megatron_row_partition = -1;
std::string view_name;
// name of the partition used to look up partitioned weight and optimizer state values
std::string partition_name;
// whether the weight itself was partitioned or not (e.g. just the optimizer state is partitioned for fp32 ZeRO-1)
bool weight_partitioned = false;
};

TrainingSession(const SessionOptions& session_options, const Environment& env)
@@ -551,6 +557,7 @@ class TrainingSession : public InferenceSession {
bool is_configured_{false};

std::unordered_set<std::string> weights_to_train_;
OptimizerState init_optimizer_states_;
// names of additional initializers to be included in checkpoints
std::unordered_map<std::string, std::string> updated_weight_names_map_;
std::unordered_set<std::string> opt_state_initializer_names_;
10 changes: 10 additions & 0 deletions orttraining/orttraining/python/training/_utils.py
@@ -242,6 +242,16 @@ def state_dict_trainer_options_world_size_key():

return 'world_size'

def state_dict_trainer_options_data_parallel_size_key():
"""Returns the trainer options data_parallel_size key name in the state dictionary"""

return 'data_parallel_size'

def state_dict_trainer_options_horizontal_parallel_size_key():
"""Returns the trainer options horizontal_parallel_size key name in the state dictionary"""

return 'horizontal_parallel_size'

def state_dict_trainer_options_optimizer_name_key():
"""Returns the trainer options optimizer_name key name in the state dictionary"""

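The two helpers added to _utils.py expose the data_parallel_size and horizontal_parallel_size keys under the trainer-options section of the frontend state dict, so aggregation code can discover how a checkpoint was partitioned. A hedged usage sketch; the 'trainer_options' section name and the import path are assumptions based on the surrounding helpers and are not shown in this diff:

```python
from onnxruntime.training import _utils  # assumed import path for the helpers above

def parallel_sizes(state_dict):
    """Read the parallelism sizes a checkpoint was saved with.

    Assumes the state dict carries a 'trainer_options' section keyed by the
    helper functions; that layout is not part of this diff.
    """
    options = state_dict["trainer_options"]
    d_size = options[_utils.state_dict_trainer_options_data_parallel_size_key()]
    h_size = options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()]
    return d_size, h_size
```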
(Remaining file diffs not shown.)

0 comments on commit 60c772e
