Skip to content

Commit

Permalink
Symbolic shape inference: fix rank for ConstantOfShape (microsoft#5912)
Browse files Browse the repository at this point in the history
  • Loading branch information
KeDengMS authored Nov 24, 2020
1 parent c2d6100 commit ee908eb
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,8 @@ def _infer_ConstantOfShape(self, node):
self.sympy_data_[node.output[0]] = np.ones([int(x) for x in sympy_shape], dtype=np.int64) * numpy_helper.to_array(get_attribute(node, 'value', 0))
else:
# create new dynamic shape
sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node,0), node)
# note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length
sympy_shape = self._new_symbolic_shape(self._get_shape(node,0)[0], node)

vi.CopyFrom(helper.make_tensor_value_info(node.output[0],
vi.type.tensor_type.elem_type,
Expand Down

0 comments on commit ee908eb

Please sign in to comment.