diff --git a/keras2onnx/_builtin.py b/keras2onnx/_builtin.py index cf6cb41e..7bd63466 100644 --- a/keras2onnx/_builtin.py +++ b/keras2onnx/_builtin.py @@ -778,39 +778,20 @@ def convert_tf_gather_nd(scope, operator, container): operator.outputs[0].full_name, name=operator.full_name) - -def _convert_tf_compare_equal(scope, operator, container, tf_op_string, onnx_op_string): - if operator.target_opset < 7: - raise ValueError(tf_op_string + " op is not supported for opset < 7") - oopb = OnnxOperatorBuilder(container, scope) - if operator.target_opset >= 9: - compare_node = oopb.add_node(onnx_op_string, - operator.input_full_names, - operator.full_name + '_' + onnx_op_string.lower()) - oopb.add_node_with_output('Not', - [compare_node], - operator.outputs[0].full_name, - name=operator.full_name) - else: - compare_input_0 = oopb.add_node('Cast', [operator.inputs[0].full_name], - operator.full_name + '_input_0_cast', to=oopb.float) - compare_input_1 = oopb.add_node('Cast', [operator.inputs[1].full_name], - operator.full_name + '_input_1_cast', to=oopb.float) - less_out = oopb.add_node(onnx_op_string, [compare_input_0, compare_input_1], - operator.full_name + '_' + onnx_op_string.lower()) - oopb.add_node_with_output('Not', less_out, - operator.output_full_names, - name=operator.full_name + '_not') - - @converter_func(TYPES.GreaterEqual) def convert_tf_greater_equal(scope, operator, container): - _convert_tf_compare_equal(scope, operator, container, 'GreaterEqual', 'Less') + oopb = OnnxOperatorBuilder(container, scope) + oopb.apply_op_with_output('apply_greater_or_equal', operator.input_full_names, + operator.output_full_names, + name=operator.full_name) @converter_func(TYPES.LessEqual) def convert_tf_less_equal(scope, operator, container): - _convert_tf_compare_equal(scope, operator, container, 'LessEqual', 'Greater') + oopb = OnnxOperatorBuilder(container, scope) + oopb.apply_op_with_output('apply_less_or_equal', operator.input_full_names, + operator.output_full_names, + name=operator.full_name) @converter_func(TYPES.LogicalAnd)