Skip to content

Commit

Permalink
make recv/send signal as initializer
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhu1900 committed Apr 21, 2020
1 parent 5739b42 commit 87b0cfd
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions orttraining/tools/scripts/pipeline_model_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def split_graph(model, split_edge_groups):
if info.name == id:
output_shapes.append(info.type)

send_signal = model.graph.input.add()
send_signal.CopyFrom(helper.make_tensor_value_info(
'send_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
send_signal = helper.make_tensor(
'send_input_signal' + str(cut_index), TensorProto.BOOL, (), (True,))
model.graph.initializer.extend([send_signal])

recv_signal = model.graph.input.add()
recv_signal.CopyFrom(helper.make_tensor_value_info(
'recv_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
recv_signal = helper.make_tensor(
'recv_input_signal' + str(cut_index), TensorProto.BOOL, (), (True,))
model.graph.initializer.extend([recv_signal])

send_dst_rank = helper.make_tensor(
'send_dst_rank' + str(cut_index), TensorProto.INT64, (), (cut_index + 1,))
Expand Down

0 comments on commit 87b0cfd

Please sign in to comment.