-
Notifications
You must be signed in to change notification settings - Fork 71
/
comparisons.py
48 lines (40 loc) · 1.79 KB
/
comparisons.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Public API of this module: only the comparison module class is exported.
__all__ = ['OnnxCompare']
import torch
from torch import nn
from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node
# Lookup table from an ONNX comparison operator name to the element-wise torch
# function that implements it. OnnxCompare indexes this table with the node's
# operation type, so every operator registered as a converter in this module
# needs an entry here.
_TORCH_FUNCTION_FROM_ONNX_TYPE = dict(
    Equal=torch.eq,
    Less=torch.less,
    LessOrEqual=torch.less_equal,
    Greater=torch.greater,
    GreaterOrEqual=torch.greater_equal,
)
class OnnxCompare(nn.Module, OnnxToTorchModule):
    """Element-wise ONNX comparison (Equal / Less / Greater / ...) as a torch module.

    The torch function to use is resolved once, at construction time, from
    ``_TORCH_FUNCTION_FROM_ONNX_TYPE``.

    Args:
        operation_type: ONNX operator name; must be a key of
            ``_TORCH_FUNCTION_FROM_ONNX_TYPE``.

    Raises:
        KeyError: if ``operation_type`` is not in the lookup table.
    """

    def __init__(self, operation_type: str):
        super().__init__()
        # Resolve the comparator up front so forward() is a single call.
        comparator = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type]
        self.compare_function = comparator

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compare ``x`` with ``y`` using the configured torch comparison function."""
        compare = self.compare_function
        return compare(x, y)
# The later operator revisions (Equal-19, LessOrEqual-16, GreaterOrEqual-16)
# only widen the accepted input dtypes (string / bfloat16) per the ONNX operator
# changelog. Numeric comparison semantics are unchanged, so the same torch
# functions serve them; they are registered so models exported at those opsets
# still find a converter.
@add_converter(operation_type='Equal', version=7)
@add_converter(operation_type='Equal', version=11)
@add_converter(operation_type='Equal', version=13)
@add_converter(operation_type='Equal', version=19)
@add_converter(operation_type='Less', version=7)
@add_converter(operation_type='Less', version=9)
@add_converter(operation_type='Less', version=13)
@add_converter(operation_type='Greater', version=7)
@add_converter(operation_type='Greater', version=9)
@add_converter(operation_type='Greater', version=13)
@add_converter(operation_type='LessOrEqual', version=12)
@add_converter(operation_type='LessOrEqual', version=16)
@add_converter(operation_type='GreaterOrEqual', version=12)
@add_converter(operation_type='GreaterOrEqual', version=16)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
    """Convert an ONNX comparison node into an ``OnnxCompare`` module.

    Args:
        node: The ONNX node being converted. Its ``operation_type`` selects the
            torch comparison function.
        graph: The enclosing ONNX graph; unused here because comparison ops
            carry no attributes or initializers this converter needs.

    Returns:
        An ``OperationConverterResult`` holding the torch module and the
        node's ONNX input/output mapping (built by ``onnx_mapping_from_node``;
        presumably a 1:1 mapping of the node's inputs/outputs — confirm in
        ``onnx2torch.utils.common``).
    """
    return OperationConverterResult(
        torch_module=OnnxCompare(operation_type=node.operation_type),
        onnx_mapping=onnx_mapping_from_node(node=node),
    )