diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index dcb8ac5f965345..8aed9e296fad9a 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -419,6 +419,42 @@ def forward(self, x):
         x = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
         self.run_test(model, (x,), rtol=1e-3, atol=1e-6)
 
+    def test_conv_tbc(self):
+        from torch.nn.modules.utils import _single
+
+        class ConvTBC(torch.nn.Module):
+            def __init__(self, in_channels, out_channels, kernel_size, padding=0):
+                super(ConvTBC, self).__init__()
+                self.in_channels = in_channels
+                self.out_channels = out_channels
+                self.kernel_size = _single(kernel_size)
+                self.padding = _single(padding)
+
+                self.weight = torch.nn.Parameter(
+                    torch.Tensor(self.kernel_size[0], in_channels, out_channels)
+                )
+                self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
+                self.reset_parameters()
+
+            def reset_parameters(self):
+                torch.nn.init.xavier_normal_(self.weight)
+                torch.nn.init.zeros_(self.bias)
+
+            def conv_tbc(self, input):
+                return torch.conv_tbc(
+                    input.contiguous(), self.weight, self.bias, self.padding[0]
+                )
+
+            def forward(self, input):
+                return self.conv_tbc(input)
+
+        in_channels = 3
+        out_channels = 5
+        kernel_size = 5
+        model = ConvTBC(in_channels, out_channels, kernel_size, padding=0)
+        x = torch.randn(10, 7, in_channels, requires_grad=True)
+        self.run_test(model, (x,), atol=1e-5)
+
     def test_reshape_constant_fold(self):
         class Reshape(torch.nn.Module):
             def __init__(self, ):
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 145e7905b46446..5da68c5b3d8faa 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1686,7 +1686,15 @@ def conv_tbc(g, input, weight, bias, pad):
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
         return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
     else:
-        return sym_help._onnx_unsupported('conv_tbc')
+        # input must have 3 dimensions, see:
+        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
+        # input = (time, batch, in_channels)
+        # weight = (kernel_width, in_channels, out_channels)
+        # bias = (out_channels,)
+        input = g.op("Transpose", input, perm_i=[1, 2, 0])
+        weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
+        conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1)
+        return g.op("Transpose", conv, perm_i=[2, 0, 1])
 
 
 @parse_args('v', 'i', 'i')
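
For reference, a minimal sanity-check sketch (not part of the patch) of the layout mapping the new symbolic relies on: torch.conv_tbc in (time, batch, channels) layout should match conv1d applied after the same permutations the exporter emits. Tensor sizes are borrowed from the test above; this is illustrative eager-mode code, not exporter code.

import torch
import torch.nn.functional as F

# Illustrative only: verify conv_tbc == Transpose -> Conv1d -> Transpose,
# mirroring the perm values used in the opset 9 symbolic.
time_steps, batch, in_channels, out_channels, kernel_size, pad = 10, 7, 3, 5, 5, 0

x = torch.randn(time_steps, batch, in_channels)          # (T, B, C_in)
w = torch.randn(kernel_size, in_channels, out_channels)  # (K, C_in, C_out)
b = torch.randn(out_channels)                            # (C_out,)

ref = torch.conv_tbc(x, w, b, pad)                       # (T_out, B, C_out)

out = F.conv1d(
    x.permute(1, 2, 0),        # (B, C_in, T),   i.e. Transpose perm=[1, 2, 0]
    w.permute(2, 1, 0),        # (C_out, C_in, K), i.e. Transpose perm=[2, 1, 0]
    b, stride=1, padding=pad,
).permute(2, 0, 1)             # back to (T_out, B, C_out), i.e. perm=[2, 0, 1]

assert torch.allclose(ref, out, atol=1e-5)

Because the symbolic expresses conv_tbc entirely with Transpose and Conv, the ATen fallback is now only taken when ONNX_ATEN_FALLBACK export is explicitly requested.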