FIRETensor: supporting multi-dimensional stacked functionals and measures.
kduxin committed Nov 7, 2022
1 parent d9b1d67 commit b560992
Showing 11 changed files with 459 additions and 377 deletions.
76 changes: 46 additions & 30 deletions firelang/function/base.py
@@ -1,9 +1,10 @@
from __future__ import division, annotations
from typing import Mapping, Any, Iterable, Callable, Union
from typing import Tuple, Mapping, Any, Iterable, Callable, Union
import operator as op
import numpy as np
from torch.nn import Module, ModuleList
import firelang
from firelang.stack import StackingSlicing
from firelang.stack import StackingSlicing, _parse_shape

__all__ = ["Functional"]

@@ -26,18 +27,18 @@ def __init__(

def __add__(self, func_or_scalar):
if isinstance(func_or_scalar, StackingSlicing):
assert self.stack_size == func_or_scalar.stack_size
assert self.shape == func_or_scalar.shape
return Functional(
locals_={"stack_size": self.stack_size},
locals_={"shape": self.shape},
prev=[self, func_or_scalar],
operator=op.add,
)

def __sub__(self, func_or_scalar):
if isinstance(func_or_scalar, StackingSlicing):
assert self.stack_size == func_or_scalar.stack_size
assert self.shape == func_or_scalar.shape
return Functional(
locals_={"stack_size": self.stack_size},
locals_={"shape": self.shape},
prev=[self, func_or_scalar],
operator=op.sub,
)
@@ -50,60 +51,52 @@ def __mul__(self, other: Union[float, Functional, firelang.Measure]):
- if it is a `Measure`, compute the paired integral.
"""
if isinstance(other, StackingSlicing) or isinstance(other, firelang.Measure):
assert self.stack_size == other.stack_size
assert self.shape == other.shape

if isinstance(other, float) or isinstance(other, Functional):
return Functional(
locals_={"stack_size": self.stack_size},
locals_={"shape": self.shape},
prev=[self, other],
operator=op.mul,
)
elif isinstance(other, firelang.Measure):
return other.integral(self, sum=False)
return other.integral(self)
else:
raise TypeError(
f"`other` must be a float or Functional or Measure object, not {type(other)}."
)

def __truediv__(self, func_or_scalar):
if isinstance(func_or_scalar, StackingSlicing):
assert self.stack_size == func_or_scalar.stack_size
assert self.shape == func_or_scalar.shape
return Functional(
locals_={"stack_size": self.stack_size},
locals_={"shape": self.shape},
prev=[self, func_or_scalar],
operator=op.truediv,
)

def __pow__(self, pow: float):
return Functional(
locals_={"stack_size": self.stack_size},
locals_={"shape": self.shape},
prev=[self, pow],
operator=op.pow,
)

def __neg__(self):
return Functional(
locals_={"stack_size": self.stack_size}, prev=[self], operator=op.neg
)
return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.neg)

def neg(self):
return Functional(
locals_={"stack_size": self.stack_size}, prev=[self], operator=op.neg
)
return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.neg)

def __abs__(self):
return Functional(
locals_={"stack_size": self.stack_size}, prev=[self], operator=op.abs
)
return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.abs)

def abs(self):
return Functional(
locals_={"stack_size": self.stack_size}, prev=[self], operator=op.abs
)
return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.abs)

def apply_op(self, mapping: Callable, *other_nodes):
return Functional(
locals_={"stack_size": self.stack_size},
locals_={"shape": self.shape},
prev=[self, *other_nodes],
operator=mapping,
)
@@ -157,9 +150,31 @@ def __getitem__(self, idx):
newop = sliced[0].apply_op(self.operator, *sliced[1:])
return newop

def view(self, *shape, inplace: bool = False):
shape = _parse_shape(shape, num_elements=np.prod(self.shape))

if inplace:
if self.is_leaf():
StackingSlicing.view(self, shape, inplace=True)
else:
for node in self.prev:
if isinstance(node, Functional):
Functional.view(node, shape, inplace=True)
return self
else:
if self.is_leaf():
newop = StackingSlicing.view(self, shape)
else:
prev = [
Functional.view(node, shape) if isinstance(node, Functional) else node
for node in self.prev
]
newop = prev[0].apply_op(self.operator, *prev[1:])
return newop

def __repr__(self):
if self.is_leaf():
return Module.__repr__(self) + f", stack_size={self.stack_size}"
return Module.__repr__(self) + f", shape={self.shape}"
else:
segs = [f"{self.__class__.__name__}("]
for i, node in enumerate(self.prev):
@@ -176,12 +191,12 @@ def __repr__(self):
segs.append(f"), operator={op_name}")
return "\n".join(segs)

def restack(self, stack_size):
def restack(self, shape: Tuple[int] = None):
if self.is_leaf():
newop = StackingSlicing.restack(self, stack_size)
newop = StackingSlicing.restack(self, shape)
else:
stacked = [
node.restack(stack_size) if hasattr(node, "restack") else node
node.restack(shape) if hasattr(node, "restack") else node
for node in self.prev
]
newop = stacked[0].apply_op(self.operator, *stacked[1:])
@@ -190,4 +205,5 @@ def restack(self, stack_size):
stack = restack

def __matmul__(self, measure: firelang.Measure):
return measure.integral(self, cross=True)
assert isinstance(measure, firelang.Measure)
return measure.integral(self, cross=True)
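The operator overloads above (__add__, __sub__, __mul__, __truediv__, __pow__, __neg__, __abs__) now key composed nodes on shape rather than stack_size: each call asserts that both operands share the same stack shape and records the operands together with an operator callable in a new Functional, deferring evaluation; view and restack then recurse over prev through the same graph. Below is a minimal, self-contained sketch of this lazy-composition pattern. LazyNode, its fn payload, and __call__ are hypothetical illustrations, not the firelang API.

import operator as op

class LazyNode:
    """Hypothetical stand-in for a stacked functional: a leaf wraps a callable;
    an inner node stores an operator plus the operand nodes it combines."""

    def __init__(self, shape, fn=None, prev=(), operator=None):
        self.shape = shape        # stack shape shared by all operands
        self.fn = fn              # leaf payload: callable applied to the input
        self.prev = list(prev)    # operand nodes (or plain scalars)
        self.operator = operator  # e.g. op.add, op.mul

    def _combine(self, other, operator):
        if isinstance(other, LazyNode):
            assert self.shape == other.shape  # same check as Functional.__add__ etc.
        return LazyNode(self.shape, prev=[self, other], operator=operator)

    def __add__(self, other):
        return self._combine(other, op.add)

    def __mul__(self, other):
        return self._combine(other, op.mul)

    def __call__(self, x):
        if self.operator is None:  # leaf: apply the wrapped callable
            return self.fn(x)
        args = [p(x) if isinstance(p, LazyNode) else p for p in self.prev]
        return self.operator(*args)  # evaluate operands first, then combine

f = LazyNode(shape=(2, 3), fn=lambda x: x ** 2)
g = LazyNode(shape=(2, 3), fn=lambda x: x + 1.0)
h = (f + g) * 0.5   # builds a two-level graph; nothing is evaluated yet
print(h(2.0))       # (2**2 + (2 + 1.0)) * 0.5 = 3.5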
2 changes: 1 addition & 1 deletion firelang/function/components/common.py
@@ -16,4 +16,4 @@ def identity(x):


def identity_deriv(x):
return torch.ones_like(x)
return torch.ones_like(x)
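Only identity and identity_deriv appear in this file's diff; sigmoid_deriv and tanh_deriv are imported by dense.py below but their bodies are not part of this commit. Presumably they are the standard closed-form derivatives evaluated at the pre-activation, along these lines (an assumption, not code from the repository):

import torch

def sigmoid_deriv(x):
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
    s = torch.sigmoid(x)
    return s * (1 - s)

def tanh_deriv(x):
    # d/dx tanh(x) = 1 - tanh(x) ** 2
    return 1 - torch.tanh(x) ** 2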
67 changes: 34 additions & 33 deletions firelang/function/components/dense.py
@@ -1,10 +1,10 @@
from typing_extensions import Literal
from functools import partial
import numpy as np
import torch
from torch import Tensor
from torch import nn
from torch.nn import Parameter
from firelang.utils.shape import check_shape_consistency
from .common import identity, identity_deriv, sigmoid_deriv, tanh_deriv
from ..base import Functional

@@ -20,15 +20,14 @@ def __init__(
output_dim,
activation="sigmoid",
norm: Literal[None, "batch", "layer"] = None,
stack_size=1,
shape=(1,),
):
Functional.__init__(self, locals())

size = np.prod(shape)
scale = 0.1 / (input_dim + output_dim) ** 0.5
self.A = Parameter(
torch.empty(stack_size, output_dim, input_dim).normal_(0, scale)
)
self.b = Parameter(torch.zeros(stack_size, output_dim))
self.A = Parameter(torch.empty(size, output_dim, input_dim).normal_(0, scale))
self.b = Parameter(torch.zeros(size, output_dim))

if activation is None:
self.act, self.actderiv = identity, identity_deriv
@@ -50,32 +49,29 @@ def __init__(

self.input_dim = input_dim
self.output_dim = output_dim
self.stack_size = stack_size
self.shape = shape
self.activation = activation
self.norm = norm

def forward(self, x, cross=False) -> Tensor:
def forward(self, x) -> Tensor:
"""
Parameters:
x: (stack2, batch, input_dim)
when cross=True, also possible to get (stack1, stack2, batch, dim)
x: (...shape, input_dim)
Returns:
if cross == True:
(stack1, stack2, batch, output_dim)
elif cross == False (and stack == stack2):
(stack1, batch, output_dim)
(...shape, output_dim)
"""
if cross == True:
if x.ndim == 3:
x = torch.einsum("tbj,sij->stbi", x, self.A) + self.b[:, None, None, :]
elif x.ndim == 4:
x = torch.einsum("stbj,sij->stbi", x, self.A) + self.b[:, None, None, :]
else:
raise ValueError(x.ndim)
else:
x = torch.einsum("sbj,sij->sbi", x, self.A) + self.b[:, None, :]
xshape = x.shape
fshape = self.shape
(*xshape, input_dim) = x.shape
check_shape_consistency(fshape, xshape)

output_dim = self.output_dim
A = self.A.view(*fshape, output_dim, input_dim)
b = self.b.view(*fshape, output_dim)

x = torch.einsum("...i,...ji->...j", x, A) + b

# normalization
xshape = x.shape
x = self.normalizer(x.reshape(np.prod(xshape[:-1]), xshape[-1])).reshape(
*xshape
)
@@ -85,18 +81,23 @@ def forward(self, x, cross=False) -> Tensor:
def jacob(self, x) -> Tensor:
"""
Parameters:
x: (stack, batch, input_dim)
x: (...shape, input_dim)
Returns:
jac: (stack, batch, output_dim, input_dim)
jacob: (...shape, output_dim, input_dim)
"""
assert self.norm is None
a = torch.einsum("sbj,sij->sbi", x, self.A) + self.b.unsqueeze(
1
) # (stack, batch, output_dim)
ad = self.actderiv(a) # (stack, batch, output_dim)
jacob = ad[..., None] * self.A.unsqueeze(
1
) # (stack, batch, output_dim, input_dim)

fshape = self.shape
(*xshape, input_dim) = x.shape
check_shape_consistency(fshape, xshape)

output_dim = self.output_dim
A = self.A.view(*fshape, output_dim, input_dim)
b = self.b.view(*fshape, output_dim)

a = torch.einsum("...i,...ji->...j", x, A) + b
ad = self.actderiv(a) # (...shape, output_dim)
jacob = ad[..., None] * A  # (...shape, output_dim, input_dim)
return jacob

def jacdet(self, x) -> Tensor:
(remainder of the dense.py diff truncated)
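The rewritten forward and jacob replace the hard-coded stack/batch einsum strings (and the cross flag) with a single broadcast contraction: A and b are viewed as (*shape, output_dim, input_dim) and (*shape, output_dim), the affine map is torch.einsum("...i,...ji->...j", x, A) + b, and the Jacobian of act(Ax + b) with respect to x is act'(Ax + b)[..., None] * A. Below is a standalone sanity check of that contraction and Jacobian against autograd, using assumed toy sizes and the sigmoid activation rather than the firelang class itself.

import torch

# Assumed toy sizes: functional shape (2, 3), input_dim = 4, output_dim = 3.
fshape, input_dim, output_dim = (2, 3), 4, 3

A = torch.randn(*fshape, output_dim, input_dim)  # per-functional weights
b = torch.randn(*fshape, output_dim)             # per-functional biases
x = torch.randn(*fshape, input_dim)              # one input point per functional

# Broadcast affine map, as in the rewritten forward().
a = torch.einsum("...i,...ji->...j", x, A) + b   # (*fshape, output_dim)
y = torch.sigmoid(a)

# Analytic Jacobian of y w.r.t. x, as in jacob(): act'(a)[..., None] * A.
s = torch.sigmoid(a)
jac_analytic = (s * (1 - s))[..., None] * A      # (*fshape, output_dim, input_dim)

# Spot-check one stack entry against autograd.
f = lambda v: torch.sigmoid(A[0, 1] @ v + b[0, 1])
jac_auto = torch.autograd.functional.jacobian(f, x[0, 1])
print(torch.allclose(jac_analytic[0, 1], jac_auto, atol=1e-5))  # expect: True

The old "tbj,sij->stbi" / "sbj,sij->sbi" branches disappear here; pairing across two stacks is presumably handled upstream (e.g. via Measure.integral(..., cross=True), as in __matmul__ above), so this layer only needs the shape-consistent broadcast case.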
(Diffs for the remaining 8 changed files are not shown.)
