diff --git a/firelang/function/base.py b/firelang/function/base.py index 54ec056..9128738 100644 --- a/firelang/function/base.py +++ b/firelang/function/base.py @@ -4,7 +4,8 @@ import numpy as np from torch.nn import Module, ModuleList import firelang -from firelang.stack import StackingSlicing, _parse_shape +from firelang.stack import StackingSlicing +from firelang.utils.shape import parse_shape __all__ = ["Functional"] @@ -148,7 +149,7 @@ def __getitem__(self, idx): return newop def view(self, *shape, inplace: bool = False): - shape = _parse_shape(shape, num_elements=int(np.prod(self.shape))) + shape = parse_shape(shape, num_elements=int(np.prod(self.shape))) if inplace: if self.is_leaf(): @@ -163,7 +164,9 @@ def view(self, *shape, inplace: bool = False): newop = StackingSlicing.view(self, shape) else: prev = [ - Functional.view(node, shape) if isinstance(node, Functional) else node + Functional.view(node, shape) + if isinstance(node, Functional) + else node for node in self.prev ] newop = prev[0].apply_op(self.operator, *prev[1:]) @@ -203,4 +206,4 @@ def restack(self, shape: Tuple[int] = None): def __matmul__(self, measure: firelang.Measure): assert isinstance(measure, firelang.Measure) - return measure.integral(self, cross=True) \ No newline at end of file + return measure.integral(self, cross=True)