Skip to content

Commit

Permalink
functional: use parse_shape from utils
Browse files Browse the repository at this point in the history
  • Loading branch information
kduxin committed Nov 8, 2022
1 parent 3be29d1 commit 37c0f6d
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions firelang/function/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import numpy as np
from torch.nn import Module, ModuleList
import firelang
from firelang.stack import StackingSlicing, _parse_shape
from firelang.stack import StackingSlicing
from firelang.utils.shape import parse_shape

__all__ = ["Functional"]

Expand Down Expand Up @@ -148,7 +149,7 @@ def __getitem__(self, idx):
return newop

def view(self, *shape, inplace: bool = False):
shape = _parse_shape(shape, num_elements=int(np.prod(self.shape)))
shape = parse_shape(shape, num_elements=int(np.prod(self.shape)))

if inplace:
if self.is_leaf():
Expand All @@ -163,7 +164,9 @@ def view(self, *shape, inplace: bool = False):
newop = StackingSlicing.view(self, shape)
else:
prev = [
Functional.view(node, shape) if isinstance(node, Functional) else node
Functional.view(node, shape)
if isinstance(node, Functional)
else node
for node in self.prev
]
newop = prev[0].apply_op(self.operator, *prev[1:])
Expand Down Expand Up @@ -203,4 +206,4 @@ def restack(self, shape: Tuple[int] = None):

def __matmul__(self, measure: firelang.Measure):
assert isinstance(measure, firelang.Measure)
return measure.integral(self, cross=True)
return measure.integral(self, cross=True)

0 comments on commit 37c0f6d

Please sign in to comment.