
Commit 0f6c888

address feedback and bug fix

xzhu1900 committed Apr 22, 2020
1 parent 87c1cfc

Showing 3 changed files with 88 additions and 9 deletions.
orttraining/orttraining/core/graph/pipeline_transformer.cc (4 additions, 1 deletion)

@@ -166,6 +166,7 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* rec

   // if we have a send forward op followed by a recv backward op, insert WaitEvent and RecordEvent in between.
   Node* record_node = nullptr;
+  Node* wait_node = nullptr;
 
   // Insert RecordEvent
   {
@@ -204,13 +205,15 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* rec
     output_args.push_back(&new_output);
     input = &new_output;
 
-    graph.AddNode(graph.GenerateNodeName("WaitEvent"),
+    auto& new_node = graph.AddNode(graph.GenerateNodeName("WaitEvent"),
                   "WaitEvent",
                   "Backward pass",
                   input_args,
                   output_args, /* output */
                   {}, /* attribute */
                   kMSDomain);
+    wait_node = &new_node;
+    ORT_UNUSED_PARAMETER(wait_node);
   }

   return Status::OK();
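
The WaitEvent/RecordEvent pair inserted above implements an event handshake between pipeline stages: RecordEvent signals that forward work at the cut has finished, and WaitEvent blocks the backward work until that signal fires. Below is a minimal Python sketch of that handshake, with threading.Event standing in for ORT's event operators; the stage functions and names are illustrative assumptions, not ORT code.

import threading

# Stand-in for the event that RecordEvent records and WaitEvent waits on.
event = threading.Event()

def forward_stage():
    print("Send (forward pass)")
    event.set()        # RecordEvent: mark the forward send as complete

def backward_stage():
    event.wait()       # WaitEvent: block until the event is recorded
    print("Recv (backward pass)")

worker = threading.Thread(target=backward_stage)
worker.start()
forward_stage()
worker.join()
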
orttraining/orttraining/test/graph/gradient_graph_builder_test.cc (71 additions, 4 deletions)

@@ -1000,9 +1000,9 @@ class PipelineBatchPlanner {

 // verify pipeline config can load and gradient graph can construct.
 TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
-  std::string filename_base = PIPELINE_MODEL_BASE;
+  PathString filename_base = PIPELINE_MODEL_BASE;
 
-  auto load_gradient_graph = [](std::string& filename) {
+  auto load_gradient_graph = [](int stageIdx, PathString& filename) {
     auto config = MakeBasicTrainingConfig();
 
     config.use_pipeline = true;
@@ -1013,15 +1013,82 @@ TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
     std::shared_ptr<Model> model;
     ASSERT_TRUE(Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
 
+    Graph& graph = model->MainGraph();
+    auto is_backward = [](Node& node) {
+      return (node.Description() == "Backward pass");
+    };
+    // check for wait/record nodes
+    Node* wait_fw{nullptr};
+    Node* wait_bw{nullptr};
+    Node* record_fw{nullptr};
+    Node* record_bw{nullptr};
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "WaitEvent") {
+        if (is_backward(node)) {
+          wait_bw = &node;
+        } else {
+          wait_fw = &node;
+        }
+      } else if (node.OpType() == "RecordEvent") {
+        if (is_backward(node)) {
+          record_bw = &node;
+        } else {
+          record_fw = &node;
+        }
+      }
+    }
+    // every partition should have a wait forward and a record backward
+    ASSERT_TRUE(wait_fw && record_bw);
+    if (stageIdx == 2) {
+      // the last partition can perform backprop right away, so it has no
+      // record forward or wait backward
+      ASSERT_TRUE(!record_fw && !wait_bw);
+    } else {
+      ASSERT_TRUE(record_fw && wait_bw);
+    }
+
+    // check for send/recv nodes
+    Node* send_fw{nullptr};
+    Node* send_bw{nullptr};
+    Node* recv_fw{nullptr};
+    Node* recv_bw{nullptr};
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "Send") {
+        if (is_backward(node)) {
+          send_bw = &node;
+        } else {
+          send_fw = &node;
+        }
+      } else if (node.OpType() == "Recv") {
+        if (is_backward(node)) {
+          recv_bw = &node;
+        } else {
+          recv_fw = &node;
+        }
+      }
+    }
+    // except for the last partition, each partition should have a send forward and a recv backward
+    if (stageIdx == 0 || stageIdx == 1) {
+      ASSERT_TRUE(send_fw && recv_bw);
+    } else {
+      ASSERT_TRUE(!send_fw && !recv_bw);
+    }
+    // except for the first partition, each partition should have a recv forward and a send backward
+    if (stageIdx == 1 || stageIdx == 2) {
+      ASSERT_TRUE(recv_fw && send_bw);
+    } else {
+      ASSERT_TRUE(!recv_fw && !send_bw);
+    }
+
     auto mp = model->ToProto();
     std::ofstream ofs(filename + ORT_TSTR("_back.onnx"), std::ofstream::binary);
     mp.SerializeToOstream(&ofs);
     ofs.close();
   };
 
   for (int i = 0; i < 3; ++i) {
-    std::string name = filename_base + std::to_string(i);
-    load_gradient_graph(name);
+    PathString name = filename_base + ToPathString(std::to_string(i));
+    load_gradient_graph(i, name);
   }
 }

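The new assertions encode, per pipeline stage, which event and communication ops a partition must contain: every partition needs a forward WaitEvent and a backward RecordEvent; the last of the three partitions additionally has no forward RecordEvent, backward WaitEvent, forward Send, or backward Recv; the first has no forward Recv or backward Send. A rough Python equivalent of the check using the onnx package is sketched below; check_pipeline_stage, the doc_string-based backward test, and the toy last-stage graph are illustrative assumptions, not ORT code.

from onnx import helper

def check_pipeline_stage(graph, stage_idx, num_stages):
    # Collect (op_type, phase) pairs; "Backward pass" in doc_string marks
    # backward ops, mirroring Node::Description() in the C++ test.
    seen = set()
    for node in graph.node:
        if node.op_type in ("WaitEvent", "RecordEvent", "Send", "Recv"):
            phase = "bw" if node.doc_string == "Backward pass" else "fw"
            seen.add((node.op_type, phase))
    # every partition should have wait forward and record backward
    assert ("WaitEvent", "fw") in seen and ("RecordEvent", "bw") in seen
    if stage_idx == num_stages - 1:
        # the last partition backprops right away: no record forward / wait
        # backward, and nothing to send forward / recv backward
        assert ("RecordEvent", "fw") not in seen and ("WaitEvent", "bw") not in seen
        assert ("Send", "fw") not in seen and ("Recv", "bw") not in seen
    else:
        assert ("RecordEvent", "fw") in seen and ("WaitEvent", "bw") in seen
        assert ("Send", "fw") in seen and ("Recv", "bw") in seen
    if stage_idx == 0:
        # the first partition has nothing to recv forward / send backward
        assert ("Recv", "fw") not in seen and ("Send", "bw") not in seen
    else:
        assert ("Recv", "fw") in seen and ("Send", "bw") in seen

# Toy last-stage graph: recv/wait on the forward pass, record/send on the backward pass.
nodes = [
    helper.make_node("Recv", [], [], doc_string=""),
    helper.make_node("WaitEvent", [], [], doc_string=""),
    helper.make_node("RecordEvent", [], [], doc_string="Backward pass"),
    helper.make_node("Send", [], [], doc_string="Backward pass"),
]
check_pipeline_stage(helper.make_graph(nodes, "last_stage", [], []), stage_idx=2, num_stages=3)
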
orttraining/tools/scripts/pipeline_model_split.py (13 additions, 4 deletions)

@@ -46,9 +46,9 @@ def split_graph(model, split_edge_groups):
                         upstream_nodes_output_index.append(i)
                         # assuming all tensors are of type float
                         element_types.append(1)
-                        for info in model.graph.value_info:
-                            if info.name == id:
-                                output_shapes.append(info.type)
+        for info in model.graph.value_info:
+            if info.name == id:
+                output_shapes.append(info.type)
 
         send_input_signal_name = 'send_input_signal' + str(cut_index)
         send_signal = model.graph.input.add()
@@ -227,6 +227,7 @@ def insert_identity(model, all_cut_inputs):
                 count += 1
                 split_edges.append(new_edge_name)
                 updated_edges[i.edgeId] = new_edge_name
+                need_shape_inference = True
             else:
                 split_edges.append(i.edgeId)
     split_edge_groups.append(split_edges)
@@ -257,6 +258,14 @@ def get_index(node_list, node):
     found = [i for i, n in enumerate(node_list) if n == node]
     return found[0] if found else None
 
+def get_identity_index_for_deleting(node_list, node):
+    for i, n in enumerate(node_list):
+        # The node's input name has been changed during send/recv insertion,
+        # but it is sufficient to just compare the type and outputs.
+        if n.op_type == 'Identity' and n.output == node.output:
+            return i
+    return None
+

# traverse the graph, group connected nodes and generate subgraph


@@ -269,7 +278,7 @@ def generate_subgraph(model, start_nodes, identity_node_list):
     # remove added identity node before copy to subgraph
     identity_node_index = []
     for n in identity_node_list:
-        identity_node_index.append(get_index(main_graph.graph.node, n))
+        identity_node_index.append(get_identity_index_for_deleting(main_graph.graph.node, n))
     identity_node_index.sort(reverse=True)
 
     for i in reversed(range(len(main_graph.graph.node))):
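
get_index compares whole NodeProtos (n == node), which stops matching once the graph is rewired: send/recv insertion renames the Identity node's input, so the cached proto no longer equals the in-graph one. get_identity_index_for_deleting matches on op_type and outputs instead. A small illustration with the two helpers above in scope; the onnx package and the tensor names here are assumptions:

from onnx import helper

original = helper.make_node('Identity', inputs=['x'], outputs=['x_copy'])
# After send/recv insertion the in-graph copy has a renamed input.
renamed = helper.make_node('Identity', inputs=['x_updated'], outputs=['x_copy'])

assert get_index([renamed], original) is None                     # proto equality fails
assert get_identity_index_for_deleting([renamed], original) == 0  # type + outputs still match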
