SoftmaxCrossEntropyLoss-12 forward and backward kernel implementation. #3465

Merged: 12 commits, Apr 16, 2020
@@ -35,6 +35,7 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
{"Dropout", {1}},
{"Slice", {1, 2, 3, 4}},
{"SparseSoftmaxCrossEntropy", {1, 2}},
{"SoftmaxCrossEntropyLoss", {1, 2}},
{"ConstantOfShape", {0}},
{"Scatter", {1}},
{"OneHot", {0, 1, 2}},
18 changes: 18 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
@@ -888,6 +888,24 @@ IMPLEMENT_GRADIENT_BUILDER(GetSparseSoftmaxCrossEntropyGradient) {
}
}

IMPLEMENT_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossGradient) {
if (GetSrcNodeInputSize() == 2) {
return std::vector<NodeDef>{
NodeDef(OpDef{"SoftmaxCrossEntropyLossGrad", kMSDomain, 1},
{GO(0), O(1), I(1)},
{GI(0)},
SrcNodeAttributes())};
} else if (GetSrcNodeInputSize() == 3) {
return std::vector<NodeDef>{
NodeDef(OpDef{"SoftmaxCrossEntropyLossGrad", kMSDomain, 1},
{GO(0), O(1), I(1), I(2)},
{GI(0)},
SrcNodeAttributes())};
} else {
ORT_ENFORCE(false, "the number of input arguments must be 2 or 3");
}
}

IMPLEMENT_GRADIENT_BUILDER(GetGlobalAveragePoolGradient) {
const ArgDef& X = I(0);

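Note: the GetSoftmaxCrossEntropyLossGradient builder above emits a single SoftmaxCrossEntropyLossGrad node that consumes the upstream gradient GO(0), the forward node's log-probability output O(1), the label I(1), and, when present, the weight I(2). For orientation only, here is a minimal standalone C++ sketch of the math such a grad op computes in the simplest case (reduction="mean", no weight, no ignore_index); the function name and signature are illustrative, not ORT APIs.

// Reference sketch (assumption: reduction="mean", no weight, no ignore_index):
//   d_logits[i][c] = (exp(log_prob[i][c]) - 1{c == label[i]}) * dY / N
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> SoftmaxCrossEntropyLossGradRef(float dY,
                                                  const std::vector<float>& log_prob,  // N x C, row-major
                                                  const std::vector<int64_t>& label,   // N
                                                  int64_t N, int64_t C) {
  std::vector<float> d_logits(static_cast<size_t>(N * C));
  for (int64_t i = 0; i < N; ++i) {
    for (int64_t c = 0; c < C; ++c) {
      const float prob = std::exp(log_prob[i * C + c]);       // recover softmax(x) from log-softmax(x)
      const float indicator = (c == label[i]) ? 1.0f : 0.0f;  // one-hot of the target class
      d_logits[i * C + c] = (prob - indicator) * dY / static_cast<float>(N);
    }
  }
  return d_logits;
}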
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
@@ -39,6 +39,7 @@ DECLARE_GRADIENT_BUILDER(GetSqueezeGradient)
DECLARE_GRADIENT_BUILDER(GetSoftmaxGradient)
DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyGradient)
DECLARE_GRADIENT_BUILDER(GetSparseSoftmaxCrossEntropyGradient)
DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossGradient)
DECLARE_GRADIENT_BUILDER(GetGlobalAveragePoolGradient)
DECLARE_GRADIENT_BUILDER(GetGemmGradient)
DECLARE_GRADIENT_BUILDER(GetDropoutGradient)
@@ -77,6 +77,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Softmax", GetSoftmaxGradient);
REGISTER_GRADIENT_BUILDER("SoftmaxCrossEntropy", GetSoftmaxCrossEntropyGradient);
REGISTER_GRADIENT_BUILDER("SparseSoftmaxCrossEntropy", GetSparseSoftmaxCrossEntropyGradient);
REGISTER_GRADIENT_BUILDER("SoftmaxCrossEntropyLoss", GetSoftmaxCrossEntropyLossGradient);
REGISTER_GRADIENT_BUILDER("GlobalAveragePool", GetGlobalAveragePoolGradient);
REGISTER_GRADIENT_BUILDER("AveragePool", GetAveragePoolGradient);
REGISTER_GRADIENT_BUILDER("Dropout", GetDropoutGradient)
29 changes: 29 additions & 0 deletions orttraining/orttraining/core/graph/gradient_schema_defs.cc
@@ -825,6 +825,35 @@ void RegisterGradientSchemas() {
"Constrain indices to integer types")
.SetDoc(R"DOC(SparseSoftmaxCrossEntropyGrad)DOC");

ONNX_CONTRIB_OPERATOR_SCHEMA(SoftmaxCrossEntropyLossGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Attr("reduction",
reduction_doc,
AttributeProto::STRING,
std::string("mean"))
.Attr(
"ignore_index",
"Specifies a target value that is ignored and does not contribute to the input gradient.",
AttributeProto::INT,
false)
.Input(0, "dY", "gradient of Y", "T")
.Input(1, "log_prob", "logsoftmax(logits), (N+1)-D input of shape (batch_size).", "T")
.Input(2, "label",
"label is N-D input whose shape should match that of logits. "
"It is a tensor of nonnegative integers, "
"where each element is the nonnegative integer label for the element of the batch.",
"Tind")
.Input(3, "weight", "weight for each sample. The shape is 1-D tensor.", "T", OpSchema::Optional)
.Output(0, "d_logits", "gradient of logits", "T")
.TypeConstraint("T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain to float, float16 and double tensors.")
.TypeConstraint("Tind",
{"tensor(int32)", "tensor(int64)"},
"Constrain indices to integer types")
.SetDoc(R"DOC(SoftmaxCrossEntropyLossGrad)DOC");

ONNX_CONTRIB_OPERATOR_SCHEMA(TrainableDropout)
.SetDomain(kOnnxDomain)
.SinceVersion(9)
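Note: the ignore_index attribute and the optional weight input declared in the SoftmaxCrossEntropyLossGrad schema above change how each sample's gradient is scaled. The sketch below is an assumption about the intended semantics, following the standard weighted negative-log-likelihood formulation for reduction="mean" (samples whose label equals ignore_index contribute nothing and are excluded from the normalizer); the helper name is hypothetical, not this PR's kernel.

// Hedged sketch: weight is indexed by class (length C), matching the
// weight_info({logit_shape[1]}) used in the gradient tests below.
// sum_w is the total weight of the non-ignored samples; dY is the upstream scalar gradient.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> WeightedSoftmaxCrossEntropyLossGradRef(
    float dY,
    const std::vector<float>& log_prob,  // N x C, row-major
    const std::vector<int64_t>& label,   // N
    const std::vector<float>& weight,    // C
    int64_t ignore_index, int64_t N, int64_t C) {
  float sum_w = 0.0f;
  for (int64_t i = 0; i < N; ++i) {
    if (label[i] != ignore_index) sum_w += weight[label[i]];
  }
  std::vector<float> d_logits(static_cast<size_t>(N * C), 0.0f);
  if (sum_w == 0.0f) return d_logits;  // every sample ignored: gradient stays all zeros
  for (int64_t i = 0; i < N; ++i) {
    if (label[i] == ignore_index) continue;  // ignored samples get a zero gradient row
    const float w = weight[label[i]];
    for (int64_t c = 0; c < C; ++c) {
      const float prob = std::exp(log_prob[i * C + c]);
      const float indicator = (c == label[i]) ? 1.0f : 0.0f;
      d_logits[i * C + c] = w * (prob - indicator) * dY / sum_w;
    }
  }
  return d_logits;
}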
@@ -92,5 +92,58 @@ GraphAugmenter::GraphDefs SparseSoftmaxCrossEntropy::operator()(
return graph_defs;
}

GraphAugmenter::GraphDefs SoftmaxCrossEntropyLoss::operator()(
const Graph& graph,
const LossFunctionInfo& loss_func_info) {
const VectorString& args = loss_func_info.loss_builder_args;
ORT_ENFORCE(args.size() == 2 || args.size() == 3, " Invalid loss_func_info for SoftmaxCrossEntropyLoss.");
const std::string& prediction_name = args[0];
const std::string& label_name = args[1];
const std::string& loss_name = loss_func_info.loss_name;
const std::string& prob_name = prediction_name + "_probability";

GraphAugmenter::GraphDefs graph_defs;
graph_defs.AddGraphOutputs({loss_name});
std::vector<NodeDef> new_nodes;

{
const NodeArg* prediction_arg = graph.GetNodeArg(prediction_name);
ORT_ENFORCE(prediction_arg != nullptr, "Prediction arg ", prediction_name, " is not found in the graph.");
TypeProto* label_type_proto = GetSparseTypeProto(prediction_arg,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
graph_defs);

if (args.size() == 3) {
const std::string& weight_name = args[2];
TypeProto* weight_type_proto = graph_defs.CreateTypeProto();
weight_type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
weight_type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->CopyFrom(
prediction_arg->TypeAsProto()->tensor_type().shape().dim()[1]);

new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss", // Op
{ArgDef(prediction_name),
ArgDef(label_name, label_type_proto),
ArgDef(weight_name, weight_type_proto)}, // Inputs
{ArgDef(loss_name),
ArgDef(prob_name, prediction_arg->TypeAsProto())}, // Outputs
NodeAttributes(),
"SoftmaxCrossEntropy" // name
));
} else {
new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss", // Op
{ArgDef(prediction_name),
ArgDef(label_name, label_type_proto)}, // Inputs
{ArgDef(loss_name),
ArgDef(prob_name, prediction_arg->TypeAsProto())}, // Outputs
NodeAttributes(),
"SoftmaxCrossEntropy" // name
));
}
}

graph_defs.AddNodeDefs(new_nodes);
return graph_defs;
}

} // namespace training
} // namespace onnxruntime
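Note: the builder above produces one SoftmaxCrossEntropyLoss node with two outputs, the loss (loss_name) and the log-probability tensor (prob_name, i.e. "<prediction>_probability"), which the gradient builder later reads back as O(1). For reference only, a minimal sketch of the forward computation for reduction="mean" with no weight input; names are illustrative, not ORT APIs.

// Numerically stable log-softmax followed by mean-reduced negative log-likelihood.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

float SoftmaxCrossEntropyLossRef(const std::vector<float>& logits,   // N x C, row-major
                                 const std::vector<int64_t>& label,  // N
                                 std::vector<float>& log_prob,       // out: N x C
                                 int64_t N, int64_t C) {
  log_prob.assign(static_cast<size_t>(N * C), 0.0f);
  float loss = 0.0f;
  for (int64_t i = 0; i < N; ++i) {
    float max_logit = logits[i * C];
    for (int64_t c = 1; c < C; ++c) max_logit = std::max(max_logit, logits[i * C + c]);
    float sum_exp = 0.0f;
    for (int64_t c = 0; c < C; ++c) sum_exp += std::exp(logits[i * C + c] - max_logit);
    const float log_sum_exp = max_logit + std::log(sum_exp);
    for (int64_t c = 0; c < C; ++c) log_prob[i * C + c] = logits[i * C + c] - log_sum_exp;
    loss += -log_prob[i * C + label[i]];  // negative log-likelihood of the target class
  }
  return loss / static_cast<float>(N);  // "mean" reduction
}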
@@ -16,5 +16,9 @@ struct SparseSoftmaxCrossEntropy : public ILossFunction {
GraphAugmenter::GraphDefs operator()(const Graph& graph, const LossFunctionInfo& loss_func_info) override;
};

struct SoftmaxCrossEntropyLoss : public ILossFunction {
GraphAugmenter::GraphDefs operator()(const Graph& graph, const LossFunctionInfo& loss_func_info) override;
};

} // namespace training
} // namespace onnxruntime
@@ -58,6 +58,7 @@ void LossFunctionRegistry::RegisterNonOperatorLossFunctions() {
REGISTER_NON_OPERATOR_LOSS_FUNCTION(BertLoss);
REGISTER_NON_OPERATOR_LOSS_FUNCTION(SoftmaxCrossEntropy);
REGISTER_NON_OPERATOR_LOSS_FUNCTION(SparseSoftmaxCrossEntropy);
REGISTER_NON_OPERATOR_LOSS_FUNCTION(SoftmaxCrossEntropyLoss);
}
} // namespace training
} // namespace onnxruntime
@@ -53,6 +53,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(T
rule_transformer->Register(make_unique<InsertMaxPoolOutput>());
rule_transformer->Register(make_unique<AdjustBatchNormOutputs>());
rule_transformer->Register(make_unique<UnsqueezeElimination>());
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());

transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(compatible_eps));
transformers.emplace_back(onnxruntime::make_unique<LayerNormFusion>(compatible_eps));
25 changes: 25 additions & 0 deletions orttraining/orttraining/core/optimizer/insert_output_rewriter.cc
@@ -32,6 +32,31 @@ bool InsertMaxPoolOutput::SatisfyCondition(const Graph& /*graph*/, const Node& n
return false;
}

Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
auto& outputs = node.MutableOutputDefs();
auto& inputs = node.MutableInputDefs();
const NodeArg* X = inputs[0];

TypeProto t;
t.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits.

NodeArg& node_arg = graph.GetOrCreateNodeArg(X->Name() + "_log_prob", &t);

outputs.push_back(&node_arg);

rule_effect = RewriteRuleEffect::kUpdatedCurrentNode;
return Status::OK();
}

bool InsertSoftmaxCrossEntropyLossOutput::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLoss", {12}) &&
node.OutputDefs().size() == 1) {
return true;
}
return false;
}

Status AdjustBatchNormOutputs::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
auto& outputs = node.MutableOutputDefs();
const auto& inputs = node.InputDefs();
17 changes: 17 additions & 0 deletions orttraining/orttraining/core/optimizer/insert_output_rewriter.h
@@ -24,6 +24,23 @@ class InsertMaxPoolOutput : public RewriteRule {
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};

// Rewrite rule that inserts an additional output to the matched node.
class InsertSoftmaxCrossEntropyLossOutput : public RewriteRule {
public:
InsertSoftmaxCrossEntropyLossOutput() noexcept
: RewriteRule("InsertSoftmaxCrossEntropyLossOutput") {
}

std::vector<std::string> TargetOpTypes() const noexcept override {
return {};
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};

// Rewrite rule that adjust Batch Normalization nodes to have 5 outputs for training mode
// instead of 1 for inference mode
class AdjustBatchNormOutputs : public RewriteRule {
102 changes: 102 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -1207,6 +1207,108 @@ TEST(GradientCheckerTest, SparseSoftmaxCrossEntropyGrad) {
TestSparseSoftmaxCrossEntropyGrad({2, 3, 2}, "sum");
}

void TestSoftmaxCrossEntropyLossGrad(const TensorShape& index_shape, const std::string& reduction,
int64_t ignore_index = 0, int64_t D = 2) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"SoftmaxCrossEntropyLoss", kOnnxDomain, 12};
std::function<float(float)> transformer_index = [D](float x) { return std::fmod(std::fabs(x) * 5.0f, D * 1.0f); };
std::function<float(float)> transformer_weight = [](float x) { return std::fmod(std::fabs(x), 2.0f); };

// without weight and ignore_index
{
std::vector<int64_t> logit_shape(index_shape.GetDims());
auto it = logit_shape.begin() + 1;
logit_shape.insert(it, D);
TensorInfo loss_info = {};
if (reduction == "none") {
loss_info = {TensorInfo(index_shape.GetDims())};
}

TensorInfo x_info(logit_shape);
TensorInfo index_info(index_shape, false, &transformer_index, DataTypeImpl::GetTensorType<int64_t>());

gradient_checker.ComputeGradientError(op_def, {x_info, index_info},
{loss_info, {logit_shape, false}}, &max_error,
{MakeAttribute("reduction", reduction)});
EXPECT_IS_TINY(max_error);
}

// with weight and no ignore_index
{
std::vector<int64_t> logit_shape(index_shape.GetDims());
auto it = logit_shape.begin() + 1;
logit_shape.insert(it, D);
TensorInfo loss_info = {};
if (reduction == "none") {
loss_info = {TensorInfo(index_shape.GetDims())};
}

TensorInfo x_info(logit_shape);
TensorInfo index_info(index_shape, false, &transformer_index, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo weight_info({logit_shape[1]}, false, &transformer_weight);

gradient_checker.ComputeGradientError(op_def, {x_info, index_info, weight_info},
{loss_info, {logit_shape, false}}, &max_error,
{MakeAttribute("reduction", reduction)});
EXPECT_IS_TINY(max_error);
}

// without weight and ignore index
{
std::vector<int64_t> logit_shape(index_shape.GetDims());
auto it = logit_shape.begin() + 1;
logit_shape.insert(it, D);
TensorInfo loss_info = {};
if (reduction == "none") {
loss_info = {TensorInfo(index_shape.GetDims())};
}

TensorInfo x_info(logit_shape);
TensorInfo index_info(index_shape, false, &transformer_index, DataTypeImpl::GetTensorType<int64_t>());

gradient_checker.ComputeGradientError(op_def, {x_info, index_info},
{loss_info, {logit_shape, false}}, &max_error,
{MakeAttribute("reduction", reduction), MakeAttribute("ignore_index", ignore_index)});
EXPECT_IS_TINY(max_error);
}

// with weight and ignore_index
{
std::vector<int64_t> logit_shape(index_shape.GetDims());
auto it = logit_shape.begin() + 1;
logit_shape.insert(it, D);
TensorInfo loss_info = {};
if (reduction == "none") {
loss_info = {TensorInfo(index_shape.GetDims())};
}

TensorInfo x_info(logit_shape);
TensorInfo index_info(index_shape, false, &transformer_index, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo weight_info({logit_shape[1]}, false, &transformer_weight);

gradient_checker.ComputeGradientError(op_def, {x_info, index_info, weight_info},
{loss_info, {logit_shape, false}}, &max_error,
{MakeAttribute("reduction", reduction), MakeAttribute("ignore_index", ignore_index)});
EXPECT_IS_TINY(max_error);
}
}

TEST(GradientCheckerTest, SoftmaxCrossEntropyLossGrad) {
TestSoftmaxCrossEntropyLossGrad({5}, "mean");
TestSoftmaxCrossEntropyLossGrad({5}, "sum");
TestSoftmaxCrossEntropyLossGrad({2}, "none");
TestSoftmaxCrossEntropyLossGrad({2, 3, 2}, "mean");
TestSoftmaxCrossEntropyLossGrad({2, 3, 2}, "sum");
TestSoftmaxCrossEntropyLossGrad({2, 3, 2}, "none");
TestSoftmaxCrossEntropyLossGrad({5}, "mean", -1);
TestSoftmaxCrossEntropyLossGrad({5}, "sum", -1);
TestSoftmaxCrossEntropyLossGrad({2}, "none", -1);
TestSoftmaxCrossEntropyLossGrad({2, 3, 2}, "mean", -1);
TestSoftmaxCrossEntropyLossGrad({2, 3, 2}, "sum", -1);
TestSoftmaxCrossEntropyLossGrad({2, 3, 2}, "none", -1);
}

TEST(GradientCheckerTest, GeluGrad) {
UnaryOpGradientTest("Gelu", kMSDomain, 1);
}