[Relax][PyTorch] Add support for torch.nn.functional.conv* #17325

Merged
merged 9 commits on Sep 3, 2024
add support for functional conv_transpose1d
mshr-h committed Sep 1, 2024
commit 2200fed1acead591755f4ec5927d817f90bf6919
65 changes: 53 additions & 12 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -807,34 +807,74 @@ def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
             groups=groups,
         )
 
-    def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        module = self.named_modules[node.target]
-        weight = self.params[module.weight]
-
+    def _conv1d_transpose_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ) -> relax.Var:
         conv1d_transpose = self.block_builder.emit(
             relax.op.nn.conv1d_transpose(
                 x,
                 weight,
-                strides=module.stride,
-                padding=module.padding,
-                dilation=module.dilation,
-                groups=module.groups,
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
                 data_layout="NCW",
                 kernel_layout="OIW",
                 out_dtype="float32",
             )
         )
 
-        if module.bias is None:
+        if bias is None:
             return conv1d_transpose
 
-        bias = self.params[module.bias]
         assert len(self.shape_of(bias)) == 1
         bias = relax.op.reshape(bias, (1, -1, 1))
 
         return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
 
+    def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+        bias = None
+        if module.bias is not None:
+            bias = self.params[module.bias]
+
+        return self._conv1d_transpose_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=module.stride,
+            padding=module.padding,
+            dilation=module.dilation,
+            groups=module.groups,
+        )
+
+    def _conv1d_transpose_functional(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        dilation = args[5] if len(args) > 5 else 1
+        groups = args[6] if len(args) > 6 else 1
+        return self._conv1d_transpose_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
     def _conv2d_impl(
         self,
         x: relax.Expr,
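
For context (an illustrative sketch, not part of the diff): _conv1d_transpose_functional indexes the traced call's positional arguments. PyTorch documents the functional signature as conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1), so the handler's args[0..4] line up with input, weight, bias, stride, and padding; output_padding is not modeled, and args[5]/args[6] are read as dilation and groups. A small FX trace shows what the handler sees for a call that stays within the first five slots:

import torch
import torch.nn.functional as F
from torch import fx

# Hypothetical function for illustration: a purely positional call,
# passing bias=None, stride=2, padding=1.
def f(x, w):
    return F.conv_transpose1d(x, w, None, 2, 1)

gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function":
        # args = (x, w, None, 2, 1): the handler reads bias=None,
        # stride=2, padding=1, and defaults for dilation and groups.
        print(node.target, node.args)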
@@ -1522,6 +1562,7 @@ def create_convert_map(self):
             "astype": self._type,
             "matmul": self._matmul,
             "conv1d": self._conv1d_functional,
+            "conv_transpose1d": self._conv1d_transpose_functional,
             "conv2d": self._conv2d_functional,
             "linear": self._linear_functional,
             "addmm": self._addmm,
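
A minimal end-to-end sketch of what this commit enables, assuming the Relax frontend's from_fx entry point; the module, parameter names, and shapes below are illustrative, not taken from the PR:

import torch
import torch.nn.functional as F
from torch import fx

from tvm.relax.frontend.torch import from_fx


class ConvT1d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # PyTorch transposed-conv weight layout: (in_channels, out_channels // groups, kernel_w)
        self.weight = torch.nn.Parameter(torch.randn(6, 4, 3))
        self.bias = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        # Functional form, now dispatched to _conv1d_transpose_functional.
        # stride=2 and padding=1 are passed positionally to match the
        # handler's args[3]/args[4] indexing.
        return F.conv_transpose1d(x, self.weight, self.bias, 2, 1)


graph_model = fx.symbolic_trace(ConvT1d())
# Input is NCW: batch 1, 6 channels, width 10.
mod = from_fx(graph_model, [((1, 6, 10), "float32")])
mod.show()  # prints the imported Relax IRModule, including the conv1d_transpose call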