Add max_norm for gradient clipping. (microsoft#6289)
* add max_norm as user option for gradient clipping

* add adam and lamb test cases for clip norm

* add frontend tests
pengwa authored Jan 20, 2021
1 parent a1b5bfc commit 453431f
Showing 23 changed files with 649 additions and 233 deletions.
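
For orientation before the diff: the new max_norm hyperparameter (exposed as the max_norm_clip attribute) is the threshold for the usual clip-by-norm rule, in which a gradient is rescaled whenever its L2 norm exceeds the threshold. The sketch below is a minimal illustration of that rule under this assumption; it is standalone C++, not the ORT Adam/Lamb kernel code this commit touches, and the function name and epsilon guard are illustrative only.

#include <algorithm>
#include <cmath>
#include <vector>

// Clip a flattened gradient so its L2 norm does not exceed max_norm.
// max_norm corresponds to the "max_norm_clip" attribute added below
// (default 1.0f in the op schemas).
void ClipGradientByNorm(std::vector<float>& grad, float max_norm) {
  float sq_sum = 0.0f;
  for (float g : grad) sq_sum += g * g;
  const float total_norm = std::sqrt(sq_sum);
  // Rescale only when the norm exceeds the threshold; the small constant
  // guards against division by zero for an all-zero gradient.
  const float clip_coef = std::min(1.0f, max_norm / (total_norm + 1e-6f));
  for (float& g : grad) g *= clip_coef;
}
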
@@ -129,10 +129,10 @@ static Status AddNcclAllReduceForGradientsWithGroups(
allreduce_outputs[i] = ArgDef(gradient_argdefs[i].name + "_AllReduce_Out", allreduced_gradient_type_proto);
}
graph_defs.AddNodeDefs({NodeDef(OpDef{"View", kMSDomain, 1},
view_inputs,
allreduce_outputs,
NodeAttributes(),
"AllReduceOutputView")});

gradient_argdefs = allreduce_outputs;
return Status::OK();
@@ -153,7 +153,7 @@ static Status AddAdasumAllReduceForGradients(
gradient_argdefs,
adasum_output_argdefs,
{ONNX_NAMESPACE::MakeAttribute("reduce_algo",
static_cast<int64_t>(adasum_reduction_type))},
"AdasumAllReduce")});
gradient_argdefs = std::move(adasum_output_argdefs);
return Status::OK();
@@ -168,7 +168,6 @@ Status AdasumOptimizerGraphBuilder::BuildInternal(
std::vector<ArgDef>& gradient_argdefs,
std::unordered_map<std::string, std::unordered_map<std::string, std::string>>& weight_to_opt_mapping,
OptimizerOutputKeyMap<std::string>& optimizer_graph_outputs) {

// Set weight update to false for optimizer
for (auto& opt_config : opt_configs_) {
opt_config.update_weight = false;
@@ -191,7 +190,7 @@ Status AdasumOptimizerGraphBuilder::BuildInternal(

const float scale = 1.0f / scale_divisor;
// Only fuse if using hierarchical reduce.
const bool fuse_scaling_outputs = opt_graph_config_.adasum_reduction_type == AdasumReductionType::GpuHierarchicalReduction ? true : false;
ORT_RETURN_IF_ERROR(AddGradientScalingNodes(nodearg_name_generator, scale, gradient_argdefs, fused_gradient_argdef, graph_defs,
opt_graph_config_.AllReduceDataType(), fuse_scaling_outputs));

@@ -200,7 +199,7 @@ Status AdasumOptimizerGraphBuilder::BuildInternal(
if (opt_graph_config_.adasum_reduction_type == AdasumReductionType::GpuHierarchicalReduction) {
#ifdef ORT_USE_NCCL
ORT_RETURN_IF_ERROR(AddNcclAllReduceForGradientsWithGroups(gradient_argdefs, fused_gradient_argdef, graph_defs,
reduced_fused_gradient_argdef, WorkerGroupType::NodeLocalDataParallel));
#else
ORT_THROW("ORT is not built with NCCL.");
#endif
@@ -15,6 +15,7 @@ class AdamOptimizerBuilder final : public OptimizerBuilder {
"beta",
"lambda",
"epsilon",
"max_norm_clip",
"do_bias_correction",
"weight_decay_mode"}) {}

@@ -25,7 +26,6 @@ class AdamOptimizerBuilder final : public OptimizerBuilder {
std::unordered_map<std::string, std::unordered_map<std::string, std::string>>& weight_to_opt_mapping,
std::vector<ArgDef>& output_weight_argdefs,
std::vector<ArgDef>& output_gradient_argdefs) const override;

};

} // namespace training
@@ -89,6 +89,7 @@ Status LambOptimizerBuilder::Build(
std::vector<float> beta;
std::vector<float> lambda;
std::vector<float> epsilon;
std::vector<float> max_norm_clip;
float ratio_min = -std::numeric_limits<float>::infinity();
float ratio_max = std::numeric_limits<float>::infinity();
int64_t do_bias_correction = 0;
@@ -157,6 +158,12 @@ Status LambOptimizerBuilder::Build(
else
epsilon.emplace_back(1e-6f);

auto max_norm_clip_iter = attrs.find("max_norm_clip");
if (max_norm_clip_iter != attrs.end())
max_norm_clip.emplace_back(max_norm_clip_iter->second);
else
max_norm_clip.emplace_back(1.0f);

auto ratio_min_iter = attrs.find("ratio_min");
if (ratio_min_iter != attrs.end()) {
// All weight tensors should have the same min ratio.
@@ -202,9 +209,7 @@ Status LambOptimizerBuilder::Build(
output_argdefs.push_back(output_gradient_argdef); // g_new
}

const auto element_type = opt_configs[i].use_mixed_precision_moments ? ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16 : ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;

weight_to_opt_mapping[weight_name] = {};
// m1 & m2 & m1_new & m2_new
@@ -245,11 +250,11 @@ Status LambOptimizerBuilder::Build(
// w_mixed_precision & w_mixed_precision_new
if (opt_configs[i].update_weight && opt_configs[i].mixed_precision_weight_arg != nullptr) {
input_argdefs.emplace_back(ArgDef(
opt_configs[i].mixed_precision_weight_arg->Name(),
opt_configs[i].mixed_precision_weight_arg->TypeAsProto()));
output_weight_argdef = ArgDef(
opt_configs[i].mixed_precision_weight_arg->Name() + "_Lamb_out",
opt_configs[i].mixed_precision_weight_arg->TypeAsProto());
output_argdefs.push_back(output_weight_argdef);
} else {
input_argdefs.emplace_back(ArgDef());
@@ -266,6 +271,7 @@ Status LambOptimizerBuilder::Build(
attribute_protos.emplace_back(ONNX_NAMESPACE::MakeAttribute("beta", beta));
attribute_protos.emplace_back(ONNX_NAMESPACE::MakeAttribute("lambda", lambda));
attribute_protos.emplace_back(ONNX_NAMESPACE::MakeAttribute("epsilon", epsilon));
attribute_protos.emplace_back(ONNX_NAMESPACE::MakeAttribute("max_norm_clip", max_norm_clip));
attribute_protos.emplace_back(ONNX_NAMESPACE::MakeAttribute("ratio_min", ratio_min));
attribute_protos.emplace_back(ONNX_NAMESPACE::MakeAttribute("ratio_max", ratio_max));
attribute_protos.emplace_back(ONNX_NAMESPACE::MakeAttribute("do_bias_correction", do_bias_correction));
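
The LambOptimizerBuilder::Build changes above follow one pattern: for each optimized weight, look up "max_norm_clip" in that weight's attribute map, fall back to 1.0f, and emit the collected values as a single FLOATS attribute on the fused Lamb node. A self-contained sketch of that collect-with-default step follows; the helper name and container types are illustrative and are not ORT's internal API.

#include <string>
#include <unordered_map>
#include <vector>

// Gather one max_norm_clip value per weight, defaulting to 1.0f when a
// weight's optimizer config does not override it.
std::vector<float> CollectMaxNormClip(
    const std::vector<std::unordered_map<std::string, float>>& per_weight_attrs) {
  std::vector<float> max_norm_clip;
  max_norm_clip.reserve(per_weight_attrs.size());
  for (const auto& attrs : per_weight_attrs) {
    const auto it = attrs.find("max_norm_clip");
    max_norm_clip.emplace_back(it != attrs.end() ? it->second : 1.0f);
  }
  // Becomes the FLOATS "max_norm_clip" attribute, one entry per weight.
  return max_norm_clip;
}
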
@@ -15,6 +15,7 @@ class LambOptimizerBuilder final : public OptimizerBuilder {
"beta",
"lambda",
"epsilon",
"max_norm_clip",
"ratio_min",
"ratio_max",
"do_bias_correction"}) {}
10 changes: 5 additions & 5 deletions orttraining/orttraining/core/graph/optimizer_builder.h
@@ -49,10 +49,10 @@ Status IsMatchingTypeAndShape(
const int32_t element_type,
const std::vector<int64_t>& expected_shape);

/**
* The configuration for optimizer builder.
*/
struct OptimizerBuilderConfig {
//The ArgDefs of the weights to optimize.
std::vector<ArgDef> weight_argdefs;

@@ -70,11 +70,11 @@ struct OptimizerBuilderConfig{
// The per weight optimizer configuration.
std::vector<OptimizerNodeConfig> opt_configs;

// (Optional) The flag to force gradient clipping. If planning
// to use the default behavior of each sub-class, should not be set.
optional<bool> enable_grad_clipping;

// The initial state for optimizer params
// shared by all weights.
NameMLValMap shared_optimizer_states{};
};
16 changes: 10 additions & 6 deletions orttraining/orttraining/core/graph/optimizer_config.h
@@ -50,7 +50,7 @@ struct OptimizerNodeConfig {
std::unordered_map<std::string, float> attributes{};
std::unordered_map<std::string, int64_t> int_attributes{};
std::string loss_scale_input_name{};
NameMLValMap initial_states{}; // initial states for optimizer initializers
bool use_mixed_precision_moments{false};
bool update_weight{true}; // indicates whether Optimizer should do weight update, or output new gradient
bool enabled{true}; // indicates whether this weight is included in the Optimizer
@@ -71,18 +71,22 @@ struct OptimizerGraphConfig {
std::string loss_scale_input_name{}; // empty string means no loss scaling factor is applied
AdasumReductionType adasum_reduction_type{AdasumReductionType::None};
bool enable_grad_norm_clip{true};
NameMLValMap shared_optimizer_states{}; // initial states for shared params, eg. 'Step' for lamb

ONNX_NAMESPACE::TensorProto_DataType AllReduceDataType() const {
if (!allreduce_in_mixed_precision_type) {
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
}

switch (mixed_precision_type) {
case MixedPrecisionDataType::FP16:
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
case MixedPrecisionDataType::BF16:
return ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
default:
return ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
}
}
};

35 changes: 24 additions & 11 deletions orttraining/orttraining/core/graph/training_op_defs.cc
@@ -172,6 +172,11 @@ OpSchema& RegisterLambOpSchema(OpSchema&& op_schema) {
"Small scalar to avoid dividing by zero.",
AttributeProto::FLOATS,
std::vector<float>(1024, 1e-6f))
.Attr(
"max_norm_clip",
"clip threshold of gradients.",
AttributeProto::FLOATS,
std::vector<float>(1024, 1.f))
.Attr(
"do_bias_correction",
"Compute unbiased 1st and 2nd momentums.",
@@ -663,6 +668,11 @@ void RegisterTrainingOpSchemas() {
"Small scalar to avoid dividing by zero.",
AttributeProto::FLOAT,
1e-8f)
.Attr(
"max_norm_clip",
"clip threshold of gradients.",
AttributeProto::FLOAT,
1.0f)
.Attr(
"do_bias_correction",
"Compute unbiased 1st and 2nd momentums.",
@@ -926,9 +936,10 @@ Example 4:
ONNX_CONTRIB_OPERATOR_SCHEMA(NcclAllReduce)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Attr("group_type", "0 - global parallel group, 1 - data parallel group, "
"2 - node local data parallel group, 3 - cross node data parallel group, "
"4 - horozontal parallel, 5 - model parallel.",
.Attr("group_type",
"0 - global parallel group, 1 - data parallel group, "
"2 - node local data parallel group, 3 - cross node data parallel group, "
"4 - horozontal parallel, 5 - model parallel.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Input(0, "input", "tensors to be reduced", "T", OpSchema::Variadic)
@@ -945,9 +956,10 @@ Example 4:
ONNX_CONTRIB_OPERATOR_SCHEMA(NcclAllGather)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Attr("group_type", "0 - global parallel group, 1 - data parallel group, "
"2 - node local data parallel group, 3 - cross node data parallel group, "
"4 - horozontal parallel, 5 - model parallel.",
.Attr("group_type",
"0 - global parallel group, 1 - data parallel group, "
"2 - node local data parallel group, 3 - cross node data parallel group, "
"4 - horozontal parallel, 5 - model parallel.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Input(0, "input", "tensors to be sent", "T", OpSchema::Variadic)
@@ -964,9 +976,10 @@ Example 4:
ONNX_CONTRIB_OPERATOR_SCHEMA(NcclReduceScatter)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Attr("group_type", "0 - global parallel group, 1 - data parallel group, "
"2 - node local data parallel group, 3 - cross node data parallel group, "
"4 - horozontal parallel, 5 - model parallel.",
.Attr("group_type",
"0 - global parallel group, 1 - data parallel group, "
"2 - node local data parallel group, 3 - cross node data parallel group, "
"4 - horozontal parallel, 5 - model parallel.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Input(0, "input", "tensors to be reduced and scattered", "T", OpSchema::Variadic)
@@ -980,7 +993,7 @@ Example 4:
assert(getAttribute(ctx, "group_type", 0) < static_cast<int64_t>(WorkerGroupType::WorkerGroupTypeCount));
})
#endif
;

ONNX_CONTRIB_OPERATOR_SCHEMA(AdasumAllReduce)
.SetDomain(kMSDomain)
@@ -1002,7 +1015,7 @@ Example 4:
propagateElemTypeFromInputToOutput(ctx, i, i);
auto typeProto = ctx.getInputType(i);
if (!hasShape(*typeProto)) {
continue;
}
propagateShapeFromInputToOutput(ctx, i, i);
}
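
One detail worth noting from the training_op_defs.cc changes above: the single-tensor Adam op declares max_norm_clip as a scalar FLOAT (default 1.0), while the fused Lamb op declares it as a FLOATS attribute with 1024 default entries, one per weight group a single Lamb node can update. A kernel consuming the Lamb attribute would therefore pick its threshold by group index, roughly as sketched below; this is illustrative code under that assumption, not the ORT Lamb kernel, and the struct and function names are mine.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct WeightGroup {
  std::vector<float> grad;  // flattened gradient for one weight group
};

// Apply a per-group clip threshold taken from a FLOATS attribute such as
// the Lamb schema's "max_norm_clip" (missing entries fall back to 1.0f).
void ApplyPerGroupClip(std::vector<WeightGroup>& groups,
                       const std::vector<float>& max_norm_clip) {
  for (std::size_t i = 0; i < groups.size(); ++i) {
    const float max_norm = i < max_norm_clip.size() ? max_norm_clip[i] : 1.0f;
    float sq_sum = 0.0f;
    for (float g : groups[i].grad) sq_sum += g * g;
    const float clip_coef = std::min(1.0f, max_norm / (std::sqrt(sq_sum) + 1e-6f));
    for (float& g : groups[i].grad) g *= clip_coef;
  }
}
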
29 changes: 14 additions & 15 deletions orttraining/orttraining/core/session/training_session.cc
@@ -381,7 +381,7 @@ Status TrainingSession::ConfigureForTraining(
std::string loss_name{};

if (config.pipeline_config.has_value()) {
// If use pipeline, first check if model contains send op. If it does, set the
// send node's output as the start tensor to build gradient graph
GetPipelineSendOutput(model_->MainGraph(), loss_name);
}
@@ -425,14 +425,14 @@ Status TrainingSession::ConfigureForTraining(
ORT_RETURN_IF_ERROR(ApplyModelParallelTransformationsToMainGraph(trainable_initializers, config_result));

weight_partition_info_ = config_result.weight_partition_info;

// Save the model after graph transformations
if (IsRootNode(config) && config.model_after_graph_transforms_path.has_value()) {
ORT_IGNORE_RETURN_VALUE(Save(
config.model_after_graph_transforms_path.value(), SaveOption::NO_RELOAD));
}

// Derive actual set of weights to train
std::unordered_set<std::string> weight_names_to_train =
!filtered_config_weight_names_to_train.empty()
? trainable_initializers
Expand Down Expand Up @@ -467,7 +467,7 @@ Status TrainingSession::ConfigureForTraining(

ORT_RETURN_IF_ERROR(BuildGradientGraph(
weight_names_to_train, loss_name, config.gradient_graph_config, *session_logger_));

if (IsRootNode(config) && config.model_with_gradient_graph_path.has_value()) {
ORT_IGNORE_RETURN_VALUE(Save(
config.model_with_gradient_graph_path.value(), SaveOption::NO_RELOAD));
@@ -495,7 +495,7 @@ Status TrainingSession::ConfigureForTraining(
}
}

// Add optimizer or gradient accumulation
if (config.optimizer_config.has_value()) {
OptimizerGraphConfig opt_graph_config{};
std::unordered_map<std::string, OptimizerNodeConfig> opt_node_configs{};
@@ -516,7 +516,7 @@ Status TrainingSession::ConfigureForTraining(
// Set eval feed names for nodes that differ between training and inferencing.
ORT_RETURN_IF_ERROR(SetEvalFeedNames());

// Add Tensorboard
if (config.tensorboard_config.has_value()) {
const auto& tensorboard_config = config.tensorboard_config.value();

@@ -526,7 +526,7 @@ Status TrainingSession::ConfigureForTraining(
tensorboard_scalar_names.emplace_back(loss_scale_input_name.value());
}

// Add some tensors from optimizer graph outputs
if (config_result.opt_config_result.has_value()) {
const auto& opt_output_key_to_graph_output_name =
config_result.opt_config_result.value().output_key_to_graph_output_name;
@@ -549,7 +549,7 @@ Status TrainingSession::ConfigureForTraining(
tensorboard_config.dump_convergence_metrics));
}

// Add GIST encoding
if (config.gist_config.has_value()) {
ORT_RETURN_IF_ERROR(AddGistEncoding());
}
@@ -595,7 +595,7 @@ static Status AddLossScaling(
return Status::OK();
}

// Add node to scale loss_name by loss_scale_input_name
GraphAugmenter::GraphDefs defs{};
*loss_scale_input_name = graph.GenerateNodeArgName("loss_scale");
const auto* loss_scale_input_type =
@@ -621,7 +621,7 @@ static Status ConfigureLossFunctionInternal(
Graph& graph,
std::string* loss_scale_input_name,
std::string& actual_loss_name) {
// Build loss function or use external one
ORT_RETURN_IF_NOT(
(loss_func_info.has_value() && loss_graph_builder) ^ external_loss_name.has_value(),
"Either loss function information should be provided or an external "
@@ -808,7 +808,6 @@ Status TrainingSession::ApplyModelParallelTransformationsToMainGraph(std::unorde
graph_transformation_mgr.Register(std::move(entry), TransformerLevel::Level1);
}

// apply transformers
Graph& graph = model_->MainGraph();
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(
graph, TransformerLevel::Level1, *session_logger_));
@@ -1688,10 +1687,10 @@ Status PipelineTrainingSession::BuildLossAndLossScaling(
loss_scale_input_name = enable_true_loss_scale ? optional<std::string>{""} : optional<std::string>{};

ORT_RETURN_IF_ERROR(BuildLoss(
external_loss_name,
loss_name,
loss_function_config,
loss_scale_input_name));

if (enable_true_loss_scale) {
TrainingConfigurationResult::MixedPrecisionConfigurationResult mp_result{};