From 0f6c888d2a4bb8f5d71b2fb90b6e10bd9294d90f Mon Sep 17 00:00:00 2001 From: Xueyun Zhu Date: Wed, 22 Apr 2020 19:37:11 +0000 Subject: [PATCH] address feedback and bug fix --- .../core/graph/pipeline_transformer.cc | 5 +- .../test/graph/gradient_graph_builder_test.cc | 75 ++++++++++++++++++- .../tools/scripts/pipeline_model_split.py | 17 ++++- 3 files changed, 88 insertions(+), 9 deletions(-) diff --git a/orttraining/orttraining/core/graph/pipeline_transformer.cc b/orttraining/orttraining/core/graph/pipeline_transformer.cc index 9b1df0469d20d..39700c6e9bddb 100644 --- a/orttraining/orttraining/core/graph/pipeline_transformer.cc +++ b/orttraining/orttraining/core/graph/pipeline_transformer.cc @@ -166,6 +166,7 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* rec // if we have a send forward op followed by a recv backward op, insert WaitEvent and RecordEvent in between. Node* record_node = nullptr; + Node* wait_node = nullptr; // Insert RecordEvent { @@ -204,13 +205,15 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* rec output_args.push_back(&new_output); input = &new_output; - graph.AddNode(graph.GenerateNodeName("WaitEvent"), + auto& new_node = graph.AddNode(graph.GenerateNodeName("WaitEvent"), "WaitEvent", "Backward pass", input_args, output_args, /* output */ {}, /* attribute */ kMSDomain); + wait_node = &new_node; + ORT_UNUSED_PARAMETER(wait_node); } return Status::OK(); diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc index 9a49fd535adf1..b7a9ce0206707 100644 --- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc @@ -1000,9 +1000,9 @@ class PipelineBatchPlanner { // verify pipeline config can load and gradient graph can construct. 
TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) { - std::string filename_base = PIPELINE_MODEL_BASE; + PathString filename_base = PIPELINE_MODEL_BASE; - auto load_gradient_graph = [](std::string& filename) { + auto load_gradient_graph = [](int stageIdx, std::string& filename) { auto config = MakeBasicTrainingConfig(); config.use_pipeline = true; @@ -1013,6 +1013,73 @@ TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) { std::shared_ptr model; ASSERT_TRUE(Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + Graph& graph = model->MainGraph(); + auto is_backward = [](Node& node) { + return (node.Description() == "Backward pass"); + }; + // check for wait/record node + Node* wait_fw{nullptr}; + Node* wait_bw{nullptr}; + Node* record_fw{nullptr}; + Node* record_bw{nullptr}; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "WaitEvent") { + if (is_backward(node)) { + wait_bw = &node; + } else { + wait_fw = &node; + } + } else if (node.OpType() == "RecordEvent") { + if (is_backward(node)) { + record_bw = &node; + } else { + record_fw = &node; + } + } + } + // every partition should have wait forward and record backward + ASSERT_TRUE(wait_fw && record_bw); + if (stageIdx == 2) { + // the last partition can perform back prop right away. 
It won't have record + forward and wait backward + ASSERT_TRUE(!record_fw && !wait_bw); + } else { + ASSERT_TRUE(record_fw && wait_bw); + } + + // check for send/recv node + Node* send_fw{nullptr}; + Node* send_bw{nullptr}; + Node* recv_fw{nullptr}; + Node* recv_bw{nullptr}; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Send") { + if (is_backward(node)) { + send_bw = &node; + } else { + send_fw = &node; + } + } else if (node.OpType() == "Recv") { + if (is_backward(node)) { + recv_bw = &node; + } else { + recv_fw = &node; + } + } + } + // except the last partition, each partition should have send forward and recv backward + if (stageIdx == 0 || stageIdx == 1) { + ASSERT_TRUE(send_fw && recv_bw); + } else { + ASSERT_TRUE(!send_fw && !recv_bw); + } + // except the first partition, each partition should have recv forward and send backward + if (stageIdx == 1 || stageIdx == 2) { + ASSERT_TRUE(recv_fw && send_bw); + } else { + ASSERT_TRUE(!recv_fw && !send_bw); + } + auto mp = model->ToProto(); std::ofstream ofs(filename + "_back.onnx", std::ofstream::binary); mp.SerializeToOstream(&ofs); @@ -1020,8 +1087,8 @@ }; for (int i = 0; i < 3; ++i) { - std::string name = filename_base + std::to_string(i); - load_gradient_graph(name); + PathString name = filename_base + ORT_TSTR(std::to_string(i)); + load_gradient_graph(i, name); } } diff --git a/orttraining/tools/scripts/pipeline_model_split.py b/orttraining/tools/scripts/pipeline_model_split.py index e0e774d862700..008e626e3257d 100644 --- a/orttraining/tools/scripts/pipeline_model_split.py +++ b/orttraining/tools/scripts/pipeline_model_split.py @@ -46,9 +46,9 @@ def split_graph(model, split_edge_groups): upstream_nodes_output_index.append(i) # assuming all tensors are of type float element_types.append(1) - for info in model.graph.value_info: - if info.name == id: - output_shapes.append(info.type) + for info in model.graph.value_info: + if 
info.name == id: + output_shapes.append(info.type) send_input_signal_name = 'send_input_signal' + str(cut_index) send_signal = model.graph.input.add() @@ -227,6 +227,7 @@ def insert_identity(model, all_cut_inputs): count += 1 split_edges.append(new_edge_name) updated_edges[i.edgeId] = new_edge_name + need_shape_inference = True else: split_edges.append(i.edgeId) split_edge_groups.append(split_edges) @@ -257,6 +258,14 @@ def get_index(node_list, node): found = [i for i, n in enumerate(node_list) if n == node] return found[0] if found else None +def get_identity_index_for_deleting(node_list, node): + for i, n in enumerate(node_list): + # The node's input name has been changed during send/recv insertion, + # but it is sufficient to just compare the type and outputs. + if (n.op_type == 'Identity' and n.output == node.output): + return i + return None + # traverse the graph, group connected nodes and generate subgraph @@ -269,7 +278,7 @@ def generate_subgraph(model, start_nodes, identity_node_list): # remove added identity node before copy to subgraph identity_node_index = [] for n in identity_node_list: - identity_node_index.append(get_index(main_graph.graph.node, n)) + identity_node_index.append(get_identity_index_for_deleting(main_graph.graph.node, n)) identity_node_index.sort(reverse=True) for i in reversed(range(len(main_graph.graph.node))):