Skip to content

Commit

Permalink
unify input and initializer
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhu1900 committed Apr 22, 2020
1 parent bca8f41 commit 87c1cfc
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 53 deletions.
Binary file modified onnxruntime/test/testdata/test_training_model_0.onnx
Binary file not shown.
Binary file modified onnxruntime/test/testdata/test_training_model_1.onnx
Binary file not shown.
Binary file modified onnxruntime/test/testdata/test_training_model_2.onnx
Binary file not shown.
70 changes: 25 additions & 45 deletions orttraining/orttraining/core/graph/pipeline_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ Status AddRecordBackward(Graph& graph,
return Status::OK();
}

Status AddWaitForward(Graph& graph, Node* recv_fw, std::vector<std::string>& new_input_names) {
Status AddWaitForward(Graph& graph, Node* /* recv_fw */, std::vector<std::string>& new_input_names) {
// Append old_input to input_args and return its pass-through value. Note that
// input_args and output_args are Wait's inputs and outputs, respectively.
auto update_wait_input_output = [&](NodeArg* old_input,
Expand All @@ -119,56 +119,36 @@ Status AddWaitForward(Graph& graph, Node* recv_fw, std::vector<std::string>& new
return wait_output;
};

if (recv_fw) {
// if we have recv op in forward pass (at the beginning of the graph), add the WaitEvent op before that.
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "WaitEvent", true /* is_forward */, input_args, new_input_names);

// recv's first input is the signal input. Re-direct it to WaitEvent's input.
auto& input_signal = recv_fw->MutableInputDefs()[0];
auto& wait_output = update_wait_input_output(input_signal, input_args, output_args);
input_signal = &wait_output;

graph.AddNode(graph.GenerateNodeName("WaitEvent"),
"WaitEvent",
"",
input_args,
output_args,
nullptr,
kMSDomain);
} else {
// the first stage doesn't have recv_fw. Add Wait for all inputs.
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "WaitEvent", true /* is_forward */, input_args, new_input_names);
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputs();
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "WaitEvent", true /* is_forward */, input_args, new_input_names);
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();

if (graph_inputs.size() == 0){
ORT_THROW("Graph ", graph.Name(), " doesn't have any inputs.");
}
if (graph_inputs.size() == 0){
ORT_THROW("Graph ", graph.Name(), " doesn't have any inputs.");
}

for (auto& input_arg : graph_inputs) {
NodeArg* mutable_input = graph.GetNodeArg(input_arg->Name());
auto& wait_output = update_wait_input_output(mutable_input, input_args, output_args);
std::vector<Node*> nodes = graph.GetMutableConsumerNodes(input_arg->Name());
for (auto& consumer_node : nodes) {
for (auto& i : consumer_node->MutableInputDefs()) {
if (i->Name() == input_arg->Name()) {
// if the node is fed by input, re-direct it to be fed by WaitEvent's output.
i = &wait_output;
}
for (auto& input_arg : graph_inputs) {
NodeArg* mutable_input = graph.GetNodeArg(input_arg->Name());
auto& wait_output = update_wait_input_output(mutable_input, input_args, output_args);
std::vector<Node*> nodes = graph.GetMutableConsumerNodes(input_arg->Name());
for (auto& consumer_node : nodes) {
for (auto& i : consumer_node->MutableInputDefs()) {
if (i->Name() == input_arg->Name()) {
// if the node is fed by input, re-direct it to be fed by WaitEvent's output.
i = &wait_output;
}
}
}
graph.AddNode(graph.GenerateNodeName("WaitEvent"),
"WaitEvent",
"",
input_args,
output_args,
nullptr,
kMSDomain);
}
graph.AddNode(graph.GenerateNodeName("WaitEvent"),
"WaitEvent",
"",
input_args,
output_args,
nullptr,
kMSDomain);

return Status::OK();
}

Expand Down
30 changes: 22 additions & 8 deletions orttraining/tools/scripts/pipeline_model_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,36 @@ def split_graph(model, split_edge_groups):
if info.name == id:
output_shapes.append(info.type)

send_input_signal_name = 'send_input_signal' + str(cut_index)
send_signal = model.graph.input.add()
send_signal.CopyFrom(helper.make_tensor_value_info(
send_input_signal_name, onnx.TensorProto.BOOL, None))
send_signal = helper.make_tensor(
'send_input_signal' + str(cut_index), TensorProto.BOOL, (), (True,))
send_input_signal_name, TensorProto.BOOL, (), (True,))
model.graph.initializer.extend([send_signal])

recv_input_signal_name = 'recv_input_signal' + str(cut_index)
recv_signal = model.graph.input.add()
recv_signal.CopyFrom(helper.make_tensor_value_info(
recv_input_signal_name, onnx.TensorProto.BOOL, None))
recv_signal = helper.make_tensor(
'recv_input_signal' + str(cut_index), TensorProto.BOOL, (), (True,))
recv_input_signal_name, TensorProto.BOOL, (), (True,))
model.graph.initializer.extend([recv_signal])

send_dst_rank_name = 'send_dst_rank' + str(cut_index)
send_dst_rank = model.graph.input.add()
send_dst_rank.CopyFrom(helper.make_tensor_value_info(
send_dst_rank_name, onnx.TensorProto.INT64, None))
send_dst_rank = helper.make_tensor(
'send_dst_rank' + str(cut_index), TensorProto.INT64, (), (cut_index + 1,))
send_dst_rank_name, TensorProto.INT64, (), (cut_index + 1,))
model.graph.initializer.extend([send_dst_rank])

recv_src_rank_name = 'recv_src_rank' + str(cut_index)
recv_src_rank = model.graph.input.add()
recv_src_rank.CopyFrom(helper.make_tensor_value_info(
recv_src_rank_name, onnx.TensorProto.INT64, None))
recv_src_rank = helper.make_tensor(
'recv_src_rank' + str(cut_index), TensorProto.INT64, (), (cut_index,))
recv_src_rank_name, TensorProto.INT64, (), (cut_index,))
model.graph.initializer.extend([recv_src_rank])

# output signal from send after cut
Expand All @@ -79,8 +95,7 @@ def split_graph(model, split_edge_groups):
new_send = model.graph.node.add()
new_send.CopyFrom(helper.make_node(
'Send',
inputs=['send_input_signal' +
str(cut_index), 'send_dst_rank' + str(cut_index)],
inputs=[send_input_signal_name, send_dst_rank_name],
outputs=['send_output_signal' + str(cut_index)],
tag=0,
domain=ms_domain,
Expand All @@ -90,8 +105,7 @@ def split_graph(model, split_edge_groups):
new_receive = model.graph.node.add()
new_receive.CopyFrom(helper.make_node(
'Recv',
inputs=['recv_input_signal' +
str(cut_index), 'recv_src_rank' + str(cut_index)],
inputs=[recv_input_signal_name, recv_src_rank_name],
outputs=['receive_output_signal' + str(cut_index)],
tag=0,
domain=ms_domain,
Expand Down

0 comments on commit 87c1cfc

Please sign in to comment.