Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support symbolic for conv_tbc (#58359) #58692

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,42 @@ def forward(self, x):
x = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
self.run_test(model, (x,), rtol=1e-3, atol=1e-6)

def test_conv_tbc(self):
from torch.nn.modules.utils import _single

class ConvTBC(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding=0):
super(ConvTBC, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _single(kernel_size)
self.padding = _single(padding)

self.weight = torch.nn.Parameter(
torch.Tensor(self.kernel_size[0], in_channels, out_channels)
)
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
self.reset_parameters()

def reset_parameters(self):
torch.nn.init.xavier_normal_(self.weight)
torch.nn.init.zeros_(self.bias)

def conv_tbc(self, input):
return torch.conv_tbc(
input.contiguous(), self.weight, self.bias, self.padding[0]
)

def forward(self, input):
return self.conv_tbc(input)

in_channels = 3
out_channels = 5
kernel_size = 5
model = ConvTBC(in_channels, out_channels, kernel_size, padding=0)
x = torch.randn(10, 7, in_channels, requires_grad=True)
self.run_test(model, (x,), atol=1e-5)

def test_reshape_constant_fold(self):
class Reshape(torch.nn.Module):
def __init__(self, ):
Expand Down
10 changes: 9 additions & 1 deletion torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,15 @@ def conv_tbc(g, input, weight, bias, pad):
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
else:
return sym_help._onnx_unsupported('conv_tbc')
# input must have 3 dimensions, see:
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
# input = (time, batch, in_channels)
# weight = (kernel_width, in_channels, out_channels)
# bias = (out_channels,)
input = g.op("Transpose", input, perm_i=[1, 2, 0])
weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1)
return g.op("Transpose", conv, perm_i=[2, 0, 1])


@parse_args('v', 'i', 'i')
Expand Down