diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 4dc49d20ff36..983bce0255d9 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -212,6 +212,20 @@ def _softmax_module(self, node: fx.Node) -> relax.Var:
         assert dim is not None
         return self.block_builder.emit(relax.op.nn.softmax(x, dim))
 
+    def _inplace_tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else 0
+            assert isinstance(k, int)
+
+            mutated = self.block_builder.emit(op(x, k))
+            self.env[node.args[0]] = mutated
+            return mutated
+
+        return convert
+
     def _tril_triu(self, op: Callable) -> Callable:
         from torch import fx
 
@@ -356,6 +370,29 @@ def _baddbmm(self, node: fx.Node) -> relax.Var:
         res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias))
         return res
 
+    def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+        bias = self.params[module.bias]
+        running_mean = self._convert_torch_tensor_to_relax(module.running_mean)
+        running_var = self._convert_torch_tensor_to_relax(module.running_var)
+        eps = module.eps
+
+        res_tuple = self.block_builder.emit(
+            relax.op.nn.batch_norm(
+                x,
+                weight,
+                bias,
+                running_mean,
+                running_var,
+                axis=1,
+                epsilon=eps,
+            )
+        )
+
+        return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
+
     def _conv1d_transpose_impl(
         self,
         x: relax.Expr,
@@ -683,6 +720,40 @@ def _conv3d_module(self, node: fx.Node) -> relax.Var:
             groups=module.groups,
         )
 
+    def _cross_entropy(self, node: fx.Node) -> relax.Expr:
+        preds = self.env[node.args[0]]
+        targets = self.env[node.args[1]]
+        weights = self.env.get(node.kwargs["weight"], None)
+        reduction = node.kwargs["reduction"]
+        ignore_index = node.kwargs["ignore_index"]
+
+        return self.block_builder.emit(
+            relax.op.nn.nll_loss(
+                relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
+            )
+        )
+
+    def _cross_entropy_module(self, node: fx.Node) -> relax.Expr:
+        preds = self.env[node.args[0]]
+        targets = self.env[node.args[1]]
+        module = self.named_modules[node.target]
+
+        weights = module.weight
+        if weights is not None:
+            if weights in self.params:
+                weights = self.params[weights]
+            else:
+                weights = relax.const(weights.numpy(), preds.struct_info.dtype)
+
+        reduction = module.reduction
+        ignore_index = module.ignore_index
+
+        return self.block_builder.emit(
+            relax.op.nn.nll_loss(
+                relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
+            )
+        )
+
     def _einsum(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
@@ -740,6 +811,80 @@ def _group_norm_module(self, node: fx.Node) -> relax.Var:
             )
         )
 
+    def _interpolate(self, node: fx.Node) -> relax.Var:
+        # torch.nn.functional.interpolate(
+        #   input, size=None, scale_factor=None, mode='nearest', align_corners=None,
+        #   recompute_scale_factor=None, antialias=False)
+        # (TODO) this is a temporary implementation for interpolate that only considers NCHW layout
+        # it basically replicates the implementation in tvm.relay.frontend.pytorch
+        data = self.env[node.args[0]]
+        size = (
+            node.args[1]
+            if len(node.args) > 1
+            else (node.kwargs["size"] if "size" in node.kwargs else None)
+        )
+        scale_factor = (
+            node.args[2]
+            if len(node.args) > 2
+            else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs else None)
+        )
+        method = (
+            node.args[3]
+            if len(node.args) > 3
+            else (node.kwargs["mode"] if "mode" in node.kwargs else "nearest")
+        )
+        align_corners = (
+            node.args[4]
+            if len(node.args) > 4
+            else (node.kwargs["align_corners"] if "align_corners" in node.kwargs else None)
+        )
+        recompute_scale_factor = (
+            node.args[5]
+            if len(node.args) > 5
+            else (
+                node.kwargs["recompute_scale_factor"]
+                if "recompute_scale_factor" in node.kwargs
+                else None
+            )
+        )
+        antialias = (
+            node.args[6]
+            if len(node.args) > 6
+            else (node.kwargs["antialias"] if "antialias" in node.kwargs else False)
+        )
+
+        assert recompute_scale_factor is None
+        assert antialias is False
+
+        if size is None:
+            shape = self.shape_of(data)
+            assert isinstance(shape, relax.ShapeExpr)
+            if isinstance(scale_factor, tuple):
+                assert len(scale_factor) == len(shape) - 2
+                size = tuple(
+                    int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape))
+                )
+            else:
+                size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape)))
+
+        if method.startswith("nearest"):
+            method = "nearest_neighbor"
+        elif method[0:2] == "bi":
+            method = method[2:]
+
+        if method == "nearest_neighbor":
+            coord_trans = "asymmetric"
+        elif align_corners:
+            coord_trans = "align_corners"
+        else:
+            coord_trans = "half_pixel"
+
+        return self.block_builder.emit(
+            relax.op.image.resize2d(
+                data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans
+            )
+        )
+
     def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var:
         from torch.fx.immutable_collections import immutable_list
         import numpy as np  # type: ignore
@@ -913,230 +1058,106 @@ def convert(node: fx.Node):
 
         return convert
 
-    ########## DataType ##########
-
-    def _float(self, node: fx.Node) -> relax.Var:
-        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32"))
-
-    def _half(self, node: fx.Node) -> relax.Var:
-        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
+    ########## Manipulation ##########
 
-    def _to(self, node: fx.Node) -> relax.Var:
-        import torch
+    def _cat(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
 
+    def _chunk(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        if len(node.args) == 2:
-            if isinstance(node.args[1], torch.dtype):
-                dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
-                return self.block_builder.emit(relax.op.astype(x, dtype))
-        elif "dtype" in node.kwargs:
-            dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env)
-            return self.block_builder.emit(relax.op.astype(x, dtype))
-        return x
+        chunks = node.args[1]
+        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.op.split(x, chunks, dim))
 
-    def _type(self, node: fx.Node) -> relax.Var:
+    def _cumsum(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
-        return self.block_builder.emit(relax.op.astype(x, dtype))
 
-    ########## Creation ##########
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        else:
+            dtype = None
+        if "out" in node.kwargs:
+            raise ValueError("specifying out for cumsum is not supported yet")
 
-    def _arange(self, node: fx.Node) -> relax.Var:
-        import torch
+        return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
 
-        start_end_step = [None, None, None]
-        if "start" in node.kwargs:
-            start_end_step[0] = node.kwargs["start"]
-        if "end" in node.kwargs:
-            start_end_step[1] = node.kwargs["end"]
-        if "step" in node.kwargs:
-            start_end_step[2] = node.kwargs["step"]
+    def _expand(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        sizes = args[1:] if len(args) > 2 else args[1]
+        broadcast_shape, in_shape = [], self.shape_of(args[0])
+        for idx, i in enumerate(sizes):
+            if isinstance(i, int) and i == -1:
+                broadcast_shape.append(in_shape[idx])
+            else:
+                broadcast_shape.append(i)
+        return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
 
-    if len(node.args) == 1:
-        assert start_end_step[1] is None
-        start_end_step[1] = node.args[0]
-    elif len(node.args) == 2:
-        assert start_end_step[0] is None
-        assert start_end_step[1] is None
-        start_end_step[0] = node.args[0]
-        start_end_step[1] = node.args[1]
-    elif len(node.args) == 3:
-        assert start_end_step[0] is None
-        assert start_end_step[1] is None
-        assert start_end_step[2] is None
-        start_end_step[0] = node.args[0]
-        start_end_step[1] = node.args[1]
-        start_end_step[2] = node.args[2]
+    def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
+        shape = self.shape_of(x)
+        start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
+        end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
+        flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)])
+        new_shape = (
+            [shape[i] for i in range(0, start_dim)]
+            + [flattened]
+            + [shape[i] for i in range(end_dim + 1, len(shape))]
+        )
+        return self.block_builder.emit(relax.op.reshape(x, new_shape))
 
-        if start_end_step[0] is None:
-            start_end_step[0] = 0
-        if start_end_step[2] is None:
-            start_end_step[2] = 1
+    def _flatten(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        start_dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("start_dim", 0)
+        end_dim = node.args[2] if len(node.args) == 3 else node.kwargs.get("end_dim", -1)
+        return self._flatten_impl(x, start_dim, end_dim)
 
-        if "dtype" in node.kwargs:
-            dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
-        elif any([isinstance(x, float) for x in start_end_step]):
-            dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype())
-        else:
-            dtype = "int64"
-        start_end_step = [
-            self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
-        ]
-        return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
+    def _flatten_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        start_dim = module.start_dim
+        end_dim = module.end_dim
+        return self._flatten_impl(x, start_dim, end_dim)
 
-    def _empty(self, node: fx.Node) -> relax.Var:
-        dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
-        return self.block_builder.emit(relax.op.zeros(node.args, dtype))
+    def _permute(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
-    def _inplace_fill(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
-        dtype = x.struct_info.dtype
-        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
-        filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
-        self.env[node.args[0]] = filled
-        return filled
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.permute_dims(x, dims))
 
-    def _tensor(self, node: fx.Node) -> relax.Var:
-        dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None
-        if isinstance(node.args[0], float):
-            return relax.const(node.args[0], dtype if dtype is not None else "float32")
-        elif isinstance(node.args[0], int):
-            return relax.const(node.args[0], dtype if dtype is not None else "int64")
-        raise ValueError("torch.tensor with value not a float or int is not accepted")
+    def _repeat(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
 
-    def _inplace_tril_triu(self, op: Callable) -> Callable:
-        from torch import fx
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
 
-        def convert(node: fx.Node) -> relax.Var:
-            x = self.env[node.args[0]]
-            k = node.args[1] if len(node.args) > 1 else 0
-            assert isinstance(k, int)
-
-            mutated = self.block_builder.emit(op(x, k))
-            self.env[node.args[0]] = mutated
-            return mutated
-
-        return convert
-
-    def _new_ones(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        self_var = args[0]
-        size = args[1:]
-        if not isinstance(size, (list, tuple)):
-            size = (size,)
-        size = relax.ShapeExpr(size)
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                relax.const(1, self_var.struct_info.dtype),
-                self_var.struct_info.dtype,
-            )
-        )
-
-    def _ones(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = args[0]
-        if not isinstance(size, (list, tuple)):
-            size = (size,)
-        size = relax.ShapeExpr(size)
-        dtype = (
-            TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
-            if "dtype" in node.kwargs
-            else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env)
-        )
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                relax.const(1, dtype),
-                dtype,
-            )
-        )
-
-    def _full(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = args[0]
-        if not isinstance(size, (list, tuple)):
-            size = (size,)
-        size = relax.ShapeExpr(size)
-        dtype = (
-            TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
-            if "dtype" in node.kwargs
-            else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env)
-        )
-        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                value,
-                dtype,
-            )
-        )
-
-    ########## Manipulation ##########
-
-    def _cat(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
-        return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
+    def _reshape(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
 
-    def _expand(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
-        broadcast_shape, in_shape = [], self.shape_of(args[0])
-        for idx, i in enumerate(args[1:]):
-            if isinstance(i, int) and i == -1:
-                broadcast_shape.append(in_shape[idx])
-            else:
-                broadcast_shape.append(i)
-        return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.reshape(x, dims))
 
-    def _flatten(self, node: fx.Node) -> relax.Var:
+    def _size(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
-        if node.target in self.named_modules:
-            module = self.named_modules[node.target]
-            start_dim = module.start_dim
-            end_dim = module.end_dim
-        else:
-            start_dim = node.args[1] if len(node.args) >= 2 else 0
-            end_dim = node.args[2] if len(node.args) == 3 else -1
         shape = self.shape_of(x)
-        start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
-        end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
-        flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)])
-        new_shape = (
-            [shape[i] for i in range(0, start_dim)]
-            + [flattened]
-            + [shape[i] for i in range(end_dim + 1, len(shape))]
-        )
-        return self.block_builder.emit(relax.op.reshape(x, new_shape))
-
-    def _permute(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        if isinstance(args[1], (torch.Size, tuple, list)):
-            return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1])))
-        return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:]))
-
-    def _reshape(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        if isinstance(args[1], (torch.Size, tuple, list)):
-            return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1])))
-        return self.block_builder.emit(relax.op.reshape(args[0], args[1:]))
+        if len(node.args) == 1:
+            assert isinstance(shape, relax.ShapeExpr)
+            return shape
+        assert len(node.args) == 2
+        idx = node.args[1]
+        return self.shape_of(x)[idx].value
 
     def _split(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         split_size = node.args[1]
-        if "dim" in node.kwargs:
-            dim = node.kwargs["dim"]
-        else:
-            dim = 0
+        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
         if isinstance(split_size, (list, tuple)):
             n_section = []
             for s in split_size[:-1]:
@@ -1146,17 +1167,18 @@ def _split(self, node: fx.Node) -> relax.Var:
             n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
         return self.block_builder.emit(relax.op.split(x, n_section, dim))
 
-    def _chunk(self, node: fx.Node) -> relax.Var:
+    def _squeeze(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        chunks = node.args[1]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        return self.block_builder.emit(relax.op.squeeze(x, dim))
 
-        if "dim" in node.kwargs:
-            dim = node.kwargs["dim"]
-        elif len(node.args) > 2:
-            dim = node.args[2]
-        else:
-            dim = 0
-        return self.block_builder.emit(relax.op.split(x, chunks, dim))
+    def _tile(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
 
     def _transpose(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
@@ -1164,50 +1186,80 @@ def _transpose(self, node: fx.Node) -> relax.Var:
         full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
         return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
 
-    def _squeeze(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-
-        if "dim" in node.kwargs:
-            dim = node.kwargs["dim"]
-        elif len(node.args) > 1:
-            dim = node.args[1]
-        else:
-            dim = None
-        return self.block_builder.emit(relax.op.squeeze(x, dim))
+    ########## Creation ##########
 
-    def _repeat(self, node: fx.Node) -> relax.Var:
+    def _arange(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
-        args = self.retrieve_args(node)
-        if isinstance(args[1], (torch.Size, tuple, list)):
-            return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1])))
-        return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
-
-    def _tile(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
+        start_end_step = [None, None, None]
+        if "start" in node.kwargs:
+            start_end_step[0] = node.kwargs["start"]
+        if "end" in node.kwargs:
+            start_end_step[1] = node.kwargs["end"]
+        if "step" in node.kwargs:
+            start_end_step[2] = node.kwargs["step"]
 
-        args = self.retrieve_args(node)
-        if isinstance(args[1], (torch.Size, tuple, list)):
-            return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1])))
-        return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
+        if len(node.args) == 1:
+            assert start_end_step[1] is None
+            start_end_step[1] = node.args[0]
+        elif len(node.args) == 2:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+        elif len(node.args) == 3:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            assert start_end_step[2] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+            start_end_step[2] = node.args[2]
 
-    def _cumsum(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
+        if start_end_step[0] is None:
+            start_end_step[0] = 0
+        if start_end_step[2] is None:
+            start_end_step[2] = 1
 
-        if "dim" in node.kwargs:
-            dim = node.kwargs["dim"]
-        elif len(node.args) > 1:
-            dim = node.args[1]
-        else:
-            dim = None
         if "dtype" in node.kwargs:
-            dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        elif any([isinstance(x, float) for x in start_end_step]):
+            dtype = self._convert_data_type(torch.get_default_dtype())
         else:
-            dtype = None
-        if "out" in node.kwargs:
-            raise ValueError("specifying out for cumsum is not supported yet")
+            dtype = "int64"
+        start_end_step = [
+            self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
+        ]
+        return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
 
-        return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
+    def _empty(self, node: fx.Node) -> relax.Var:
+        dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
+
+    def _inplace_fill(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+        filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+        self.env[node.args[0]] = filled
+        return filled
+
+    def _full(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                value,
+                dtype,
+            )
+        )
 
     def _index_select(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
@@ -1215,14 +1267,6 @@ def _index_select(self, node: fx.Node) -> relax.Var:
         index = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.take(x, index, dim))
 
-    def _masked_fill(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        mask = self.env[node.args[1]]
-        value = node.args[2]
-        rx_value = relax.const(value)
-        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
-        return self.block_builder.emit(relax.op.where(mask, values, x))
-
     def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
@@ -1233,168 +1277,79 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         self.env[node.args[0]] = output
         return output
 
-    ########## Neural Network ##########
-
-    def _softmax(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        if node.target in self.named_modules:
-            module = self.named_modules[node.target]
-            dim = module.dim
-        else:
-            nargs = len(node.args)
-            dim = node.args[1] if nargs > 1 else node.kwargs["dim"]
-        assert dim is not None
-        return self.block_builder.emit(relax.op.nn.softmax(x, dim))
-
-    def _batch_norm_2d(self, node: fx.Node) -> relax.Var:
+    def _masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        module = self.named_modules[node.target]
-        weight = self.params[module.weight]
-        bias = self.params[module.bias]
-        running_mean = self._convert_torch_tensor_to_relax(module.running_mean)
-        running_var = self._convert_torch_tensor_to_relax(module.running_var)
-        eps = module.eps
+        mask = self.env[node.args[1]]
+        rx_value = relax.const(node.args[2])
+        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+        return self.block_builder.emit(relax.op.where(mask, values, x))
 
-        res_tuple = self.block_builder.emit(
-            relax.op.nn.batch_norm(
-                x,
-                weight,
-                bias,
-                running_mean,
-                running_var,
-                axis=1,
-                epsilon=eps,
+    def _new_ones(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        self_var = args[0]
+        size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
+        if not isinstance(size, (list, tuple)):
+            size = (size,)
+        size = relax.ShapeExpr(size)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, self_var.struct_info.dtype),
+                self_var.struct_info.dtype,
             )
         )
 
-        return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
+    def _ones(self, node: fx.Node) -> relax.Var:
+        import torch
 
-    def _interpolate(self, node: fx.Node) -> relax.Var:
-        # torch.nn.functional.interpolate(
-        #   input, size=None, scale_factor=None, mode='nearest', align_corners=None,
-        #   recompute_scale_factor=None, antialias=False)
-        # (TODO) this is a temporary implementation for interpolate that only considers NCHW layout
-        # it basically replicates the implementation in tvm.relay.frontend.pytorch
-        data = self.env[node.args[0]]
-        size = (
-            node.args[1]
-            if len(node.args) > 1
-            else (node.kwargs["size"] if "size" in node.kwargs else None)
-        )
-        scale_factor = (
-            node.args[2]
-            if len(node.args) > 2
-            else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs else None)
-        )
-        method = (
-            node.args[3]
-            if len(node.args) > 3
-            else (node.kwargs["mode"] if "mode" in node.kwargs else "nearest")
-        )
-        align_corners = (
-            node.args[4]
-            if len(node.args) > 4
-            else (node.kwargs["align_corners"] if "align_corners" in node.kwargs else None)
-        )
-        recompute_scale_factor = (
-            node.args[5]
-            if len(node.args) > 5
-            else (
-                node.kwargs["recompute_scale_factor"]
-                if "recompute_scale_factor" in node.kwargs
-                else None
-            )
-        )
-        antialias = (
-            node.args[6]
-            if len(node.args) > 6
-            else (node.kwargs["antialias"] if "antialias" in node.kwargs else False)
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
         )
-
-        assert recompute_scale_factor is None
-        assert antialias is False
-
-        if size is None:
-            shape = self.shape_of(data)
-            assert isinstance(shape, relax.ShapeExpr)
-            if isinstance(scale_factor, tuple):
-                assert len(scale_factor) == len(shape) - 2
-                size = tuple(
-                    int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape))
-                )
-            else:
-                size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape)))
-
-        if method.startswith("nearest"):
-            method = "nearest_neighbor"
-        elif method[0:2] == "bi":
-            method = method[2:]
-
-        if method == "nearest_neighbor":
-            coord_trans = "asymmetric"
-        elif align_corners:
-            coord_trans = "align_corners"
-        else:
-            coord_trans = "half_pixel"
-
         return self.block_builder.emit(
-            relax.op.image.resize2d(
-                data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans
+            relax.op.full(
+                size,
+                relax.const(1, dtype),
+                dtype,
             )
         )
 
-    def _cross_entropy(self, node: fx.Node) -> relax.Expr:
-        preds = self.env[node.args[0]]
-        targets = self.env[node.args[1]]
-
-        # functional.cross_entropy
-        if node.target not in self.named_modules:
-            weights = node.kwargs["weight"]
-            if weights is not None:
-                weights = self.env[weights]
-            reduction = node.kwargs["reduction"]
-            ignore_index = node.kwargs["ignore_index"]
-
-            return self.block_builder.emit(
-                relax.op.nn.nll_loss(
-                    relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
-                )
-            )
+    def _tensor(self, node: fx.Node) -> relax.Var:
+        dtype = node.kwargs.get("dtype", None)
+        if isinstance(node.args[0], float):
+            return relax.const(node.args[0], dtype if dtype is not None else "float32")
+        elif isinstance(node.args[0], int):
+            return relax.const(node.args[0], dtype if dtype is not None else "int64")
+        raise ValueError("torch.tensor with value not a float or int is not accepted")
 
-        module = self.named_modules[node.target]
+    ########## DataType ##########
 
-        weights = module.weight
-        if weights is not None:
-            if weights in self.params:
-                weights = self.params[weights]
-            else:
-                weights = relax.const(weights.numpy(), preds.struct_info.dtype)
-        reduction = module.reduction
-        ignore_index = module.ignore_index
+    def _float(self, node: fx.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32"))
 
-        return self.block_builder.emit(
-            relax.op.nn.nll_loss(
-                relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
-            )
-        )
+    def _half(self, node: fx.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
 
-    ########## Others ##########
+    def _to(self, node: fx.Node) -> relax.Var:
+        import torch
 
-    def _sym_size_int(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
-        shape = self.shape_of(x)
-        idx = node.args[1]
-        return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
+        if len(node.args) == 2:
+            if isinstance(node.args[1], torch.dtype):
+                dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
+                return self.block_builder.emit(relax.op.astype(x, dtype))
+        elif "dtype" in node.kwargs:
+            dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env)
+            return self.block_builder.emit(relax.op.astype(x, dtype))
+        return x
 
-    def _size(self, node: fx.Node) -> relax.Expr:
+    def _type(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        shape = self.shape_of(x)
-        if len(node.args) == 1:
-            assert isinstance(shape, relax.ShapeExpr)
-            return shape
-        assert len(node.args) == 2
-        idx = node.args[1]
-        return self.shape_of(x)[idx].value
+        dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
+        return self.block_builder.emit(relax.op.astype(x, dtype))
+
+    ########## Others ##########
 
     def _getattr(self, node: fx.Node) -> relax.Var:
         if isinstance(self.env[node.args[0]], relax.Expr):
@@ -1485,6 +1440,12 @@ def _getitem(self, node: fx.Node) -> relax.Var:
         else:
             assert False
 
+    def _sym_size_int(self, node: fx.Node) -> relax.Expr:
+        x = self.env[node.args[0]]
+        shape = self.shape_of(x)
+        idx = node.args[1]
+        return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
+
     def create_convert_map(self):
         import operator
         from torch import nn
@@ -1511,20 +1472,20 @@ def create_convert_map(self):
             # neural network
             nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module,
             nn.AvgPool2d: self._avg_pool2d_module,
-            nn.BatchNorm2d: self._batch_norm_2d,
+            nn.BatchNorm2d: self._batch_norm_2d_module,
             nn.Conv1d: self._conv1d_module,
             nn.Conv2d: self._conv2d_module,
             nn.Conv3d: self._conv3d_module,
             nn.ConvTranspose1d: self._conv1d_transpose_module,
             nn.ConvTranspose2d: self._conv2d_transpose_module,
-            nn.CrossEntropyLoss: self._cross_entropy,
+            nn.CrossEntropyLoss: self._cross_entropy_module,
             nn.GroupNorm: self._group_norm_module,
             nn.LayerNorm: self._layer_norm_module,
             nn.Linear: self._linear_module,
             nn.MaxPool2d: self._max_pool2d_module,
             nn.modules.sparse.Embedding: self._embedding_module,
             # tensor manipulation
-            nn.Flatten: self._flatten,
+            nn.Flatten: self._flatten_module,
             ## call_function and call_method
             # unary
             "acos": self._unary_op(relax.op.acos),
@@ -1603,6 +1564,7 @@ def create_convert_map(self):
             "argmin": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
             "cat": self._cat,
+            "chunk": self._chunk,
             "concat": self._cat,
             "contiguous": lambda node: self.env[node.args[0]],
             "cumsum": self._cumsum,
@@ -1622,7 +1584,6 @@ def create_convert_map(self):
             "view": self._reshape,
             # tensor creation
             "arange": self._arange,
-            "chunk": self._chunk,
             "empty": self._empty,
             "fill_": self._inplace_fill,
             "full": self._full,
@@ -1632,11 +1593,11 @@ def create_convert_map(self):
             "new_ones": self._new_ones,
             "ones": self._ones,
             "tensor": self._tensor,
-            "to": self._to,
             # datatype
             "astype": self._type,
             "float": self._float,
             "half": self._half,
+            "to": self._to,
             "type": self._type,
             # other
             "getattr": self._getattr,
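Not part of the patch itself: a minimal usage sketch of how the converters touched by this diff (for example the `nn.BatchNorm2d` module converter and the `interpolate` handler) get exercised when importing an FX-traced model through the public `from_fx` entry point. The toy module and names below are hypothetical and only for illustration.

```python
# Hypothetical example, not part of the patch: import an FX-traced module that
# exercises the BatchNorm2d and interpolate converters changed above.
import torch
from torch import fx, nn
import torch.nn.functional as F

from tvm.relax.frontend.torch import from_fx


class ToyBlock(nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.bn(x)
        # NCHW input; "nearest" maps to nearest_neighbor in the _interpolate converter
        return F.interpolate(x, scale_factor=2.0, mode="nearest")


graph_module = fx.symbolic_trace(ToyBlock().eval())
# input_info is a list of (shape, dtype) pairs for the graph inputs.
mod = from_fx(graph_module, [((1, 3, 8, 8), "float32")])
print(mod)
```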