From 9d4e9a477e49d54a58ce61e7f8639ea6a0dd861e Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Thu, 20 May 2021 13:20:50 -0700 Subject: [PATCH] Support symbolic for conv_tbc (#58359) This is a fix for exporting fairseq models, see: ```python model = torch.hub.load(github, 'conv.wmt14.en-fr', tokenizer='moses', bpe='subword_nmt') model = torch.hub.load(github, 'conv.wmt17.en-de', tokenizer='moses', bpe='subword_nmt') ``` With this fix, and with the single `GradMultiply` line in the model script commented out, these two models can be exported successfully with performance requirements met. The original PR https://github.com/pytorch/pytorch/pull/57708 had a merge issue; use this one instead. Co-authored-by: David [ghstack-poisoned] --- test/onnx/test_pytorch_onnx_onnxruntime.py | 36 ++++++++++++++++++++++ torch/onnx/symbolic_opset9.py | 10 +++++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 87870b28310e1..51cfccc3d3987 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -394,6 +394,42 @@ def forward(self, x): x = torch.randn(2, 3, 10, 50, 100, requires_grad=True) self.run_test(model, (x,), rtol=1e-3, atol=1e-6) + def test_conv_tbc(self): + from torch.nn.modules.utils import _single + + class ConvTBC(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding=0): + super(ConvTBC, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _single(kernel_size) + self.padding = _single(padding) + + self.weight = torch.nn.Parameter( + torch.Tensor(self.kernel_size[0], in_channels, out_channels) + ) + self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.xavier_normal_(self.weight) + torch.nn.init.zeros_(self.bias) + + def conv_tbc(self, input): + return 
torch.conv_tbc( + input.contiguous(), self.weight, self.bias, self.padding[0] + ) + + def forward(self, input): + return self.conv_tbc(input) + + in_channels = 3 + out_channels = 5 + kernel_size = 5 + model = ConvTBC(in_channels, out_channels, kernel_size, padding=0) + x = torch.randn(10, 7, in_channels, requires_grad=True) + self.run_test(model, (x,), atol=1e-5) + def test_reshape_constant_fold(self): class Reshape(torch.nn.Module): def __init__(self, ): diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 88643765887bc..e4edfd562eedc 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -1684,7 +1684,15 @@ def conv_tbc(g, input, weight, bias, pad): if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad) else: - return sym_help._onnx_unsupported('conv_tbc') + # input must have 3 dimensions, see: + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 + # input = (time, batch, in_channels) + # weight = (kernel_width, in_channels, out_channels) + # bias = (out_channels,) + input = g.op("Transpose", input, perm_i=[1, 2, 0]) + weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) + conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) + return g.op("Transpose", conv, perm_i=[2, 0, 1]) @parse_args('v', 'i', 'i')