forked from kduxin/firelang
Commit
Showing 21 changed files with 1,391 additions and 0 deletions.
@@ -0,0 +1,7 @@
import os

debug_on = {"on": True, "off": False}.get(os.environ.get("FIRE_DEBUG", "off"), False)

from .measure import *
from .function import *
from .models import *
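Aside: a minimal usage sketch (illustrative, not from the commit) of the debug switch above, assuming the package is importable as `firelang`. The flag is read once at import time, so `FIRE_DEBUG` must be set before the first import:

import os

os.environ["FIRE_DEBUG"] = "on"   # must be set before firelang is first imported

import firelang
print(firelang.debug_on)          # True; any other value (including unset) yields False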
@@ -0,0 +1,3 @@
from .base import *
from .functional import *
from .compose import *
@@ -0,0 +1,193 @@
from __future__ import division, annotations
from typing import Mapping, Any, Iterable, Callable, Union
import operator as op
from torch.nn import Module, ModuleList
import firelang
from firelang.stack import StackingSlicing

__all__ = ["Functional"]


class Functional(StackingSlicing):
    def __init__(
        self,
        locals_: Mapping[str, Any],
        unsliceable_params: Iterable[str] = [],
        prev=[],
        operator=None,
    ):
        StackingSlicing.__init__(
            self, locals_=locals_, unsliceable_params=unsliceable_params
        )
        prevm = [p for p in prev if isinstance(p, Module)]
        self.prevm = ModuleList(prevm) if len(prevm) else prevm
        self.prev = prev
        self.operator = operator

    def __add__(self, func_or_scalar):
        if isinstance(func_or_scalar, StackingSlicing):
            assert self.stack_size == func_or_scalar.stack_size
        return Functional(
            locals_={"stack_size": self.stack_size},
            prev=[self, func_or_scalar],
            operator=op.add,
        )

    def __sub__(self, func_or_scalar):
        if isinstance(func_or_scalar, StackingSlicing):
            assert self.stack_size == func_or_scalar.stack_size
        return Functional(
            locals_={"stack_size": self.stack_size},
            prev=[self, func_or_scalar],
            operator=op.sub,
        )

    def __mul__(self, other: Union[float, Functional, firelang.Measure]):
        """
        Args:
            other (Union[float, Functional, Measure]):
                - if `float` or `Functional`: returns a new `Functional`.
                - if `Measure`: computes the paired integral.
        """
        if isinstance(other, StackingSlicing) or isinstance(other, firelang.Measure):
            assert self.stack_size == other.stack_size

        if isinstance(other, float) or isinstance(other, Functional):
            return Functional(
                locals_={"stack_size": self.stack_size},
                prev=[self, other],
                operator=op.mul,
            )
        elif isinstance(other, firelang.Measure):
            return other.integral(self)
        else:
            raise TypeError(
                f"`other` must be a float, Functional, or Measure object, not {type(other)}."
            )

    def __truediv__(self, func_or_scalar):
        if isinstance(func_or_scalar, StackingSlicing):
            assert self.stack_size == func_or_scalar.stack_size
        return Functional(
            locals_={"stack_size": self.stack_size},
            prev=[self, func_or_scalar],
            operator=op.truediv,
        )

    def __pow__(self, pow: float):
        return Functional(
            locals_={"stack_size": self.stack_size},
            prev=[self, pow],
            operator=op.pow,
        )

    def __neg__(self):
        return Functional(
            locals_={"stack_size": self.stack_size}, prev=[self], operator=op.neg
        )

    def neg(self):
        return Functional(
            locals_={"stack_size": self.stack_size}, prev=[self], operator=op.neg
        )

    def __abs__(self):
        return Functional(
            locals_={"stack_size": self.stack_size}, prev=[self], operator=op.abs
        )

    def abs(self):
        return Functional(
            locals_={"stack_size": self.stack_size}, prev=[self], operator=op.abs
        )

    def apply_op(self, mapping: Callable, *other_nodes):
        return Functional(
            locals_={"stack_size": self.stack_size},
            prev=[self, *other_nodes],
            operator=mapping,
        )

    def forward(self, x, *args, **kwargs):
        operator = self.operator
        prev = self.prev

        def _eval(node):
            # Composite operands are evaluated recursively; scalar operands
            # (e.g. the float from `func * 2.0`) pass through unchanged.
            return (
                node.forward(x, *args, **kwargs)
                if isinstance(node, Functional)
                else node
            )

        if operator is None:
            # A leaf Functional carries no operator and must override forward itself.
            raise NotImplementedError
        elif operator is op.add:
            return _eval(prev[0]) + _eval(prev[1])
        elif operator is op.sub:
            return _eval(prev[0]) - _eval(prev[1])
        elif operator is op.mul:
            return _eval(prev[0]) * _eval(prev[1])
        elif operator is op.truediv:
            return _eval(prev[0]) / _eval(prev[1])
        elif operator is op.pow:
            return _eval(prev[0]) ** prev[1]
        elif operator is op.neg:
            return -_eval(prev[0])
        elif operator is op.abs:
            return _eval(prev[0]).abs()
        elif isinstance(operator, Callable):
            return operator(prev, x, *args, **kwargs)
        else:
            raise ValueError(f"Unrecognized operator: {operator}")

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def is_leaf(self):
        return not hasattr(self, "prev") or not len(self.prev)

    def __getitem__(self, idx):
        if self.is_leaf():
            newop = StackingSlicing.__getitem__(self, idx)
        else:
            sliced = [
                node[idx] if isinstance(node, StackingSlicing) else node
                for node in self.prev
            ]
            newop = sliced[0].apply_op(self.operator, *sliced[1:])
        return newop

    def __repr__(self):
        if self.is_leaf():
            return Module.__repr__(self) + f", stack_size={self.stack_size}"
        else:
            segs = [f"{self.__class__.__name__}("]
            for i, node in enumerate(self.prev):
                for j, line in enumerate(repr(node).split("\n")):
                    if j == 0:
                        segs.append(" " + f"prev[{i}]: " + line)
                    else:
                        segs.append(" " + line)
            op_name = (
                self.operator.__name__
                if hasattr(self.operator, "__name__")
                else self.operator.__class__.__name__
            )
            segs.append(f"), operator={op_name}")
            return "\n".join(segs)

    def restack(self, stack_size):
        if self.is_leaf():
            newop = StackingSlicing.restack(self, stack_size)
        else:
            stacked = [
                node.restack(stack_size) if hasattr(node, "restack") else node
                for node in self.prev
            ]
            newop = stacked[0].apply_op(self.operator, *stacked[1:])
        return newop

    stack = restack

    def __matmul__(self, measure: firelang.Measure):
        return measure.integral(self, cross=True)
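Aside: a minimal sketch (illustrative, not from the commit) of how these operator overloads compose lazily. It assumes a concrete leaf subclass such as the `Perceptron` added later in this commit, re-exported at package level; the exact file paths and import locations are not shown in this diff.

import torch
from firelang.function import Perceptron  # hypothetical import path

# Two stacks of 4 independent perceptrons mapping R^2 -> R^2.
f = Perceptron(input_dim=2, output_dim=2, stack_size=4)
g = Perceptron(input_dim=2, output_dim=2, stack_size=4)

# Arithmetic builds a lazy graph: h is a composite Functional with
# prev=[abs-node, 2.0] and operator=op.pow; nothing is evaluated yet.
h = (f + g).abs() ** 2.0

x = torch.randn(4, 8, 2)   # (stack_size, batch, input_dim)
y = h(x)                   # evaluates |f(x) + g(x)| ** 2 via Functional.forward
print(y.shape)             # torch.Size([4, 8, 2])
print(h)                   # tree-style repr listing the prev[i] nodes and the operator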
@@ -0,0 +1,4 @@
from .dense import *
from .planar import *
from .multilayer import *
from .common import *
@@ -0,0 +1,19 @@
import torch


def sigmoid_deriv(x):
    y = torch.sigmoid(x)
    return y * (1 - y)


def tanh_deriv(x):
    y = torch.tanh(x)
    return 1 - y**2


def identity(x):
    return x


def identity_deriv(x):
    return torch.ones_like(x)
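Aside: these closed-form derivatives (sigmoid'(x) = sigmoid(x)(1 - sigmoid(x)), tanh'(x) = 1 - tanh(x)**2) can be sanity-checked against autograd. A minimal sketch (illustrative, not from the commit), using a hypothetical import path since file names are not shown in this diff:

import torch
from firelang.function.functional.common import sigmoid_deriv, tanh_deriv  # hypothetical path

x = torch.randn(16, requires_grad=True)

# d/dx sum(sigmoid(x)) is the elementwise sigmoid derivative.
(auto_grad,) = torch.autograd.grad(torch.sigmoid(x).sum(), x)
assert torch.allclose(auto_grad, sigmoid_deriv(x))

# d/dx sum(tanh(x)) is the elementwise tanh derivative.
(auto_grad,) = torch.autograd.grad(torch.tanh(x).sum(), x)
assert torch.allclose(auto_grad, tanh_deriv(x))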
@@ -0,0 +1,103 @@
from typing_extensions import Literal
from functools import partial
import numpy as np
import torch
from torch import Tensor
from torch import nn
from torch.nn import Parameter
from .common import identity, identity_deriv, sigmoid_deriv, tanh_deriv
from ..base import Functional

__all__ = [
    "Perceptron",
]


class Perceptron(Functional):
    def __init__(
        self,
        input_dim,
        output_dim,
        activation="sigmoid",
        norm: Literal[None, "batch", "layer"] = None,
        stack_size=1,
    ):
        Functional.__init__(self, locals())

        scale = 0.1 / (input_dim + output_dim) ** 0.5
        self.A = Parameter(
            torch.empty(stack_size, output_dim, input_dim).normal_(0, scale)
        )
        self.b = Parameter(torch.zeros(stack_size, output_dim))

        if activation is None:
            self.act, self.actderiv = identity, identity_deriv
        elif activation == "sigmoid":
            self.act, self.actderiv = torch.sigmoid, sigmoid_deriv
        elif activation == "tanh":
            self.act, self.actderiv = torch.tanh, tanh_deriv
        else:
            raise ValueError(activation)

        if norm is None:
            self.normalizer = identity
        elif norm == "batch":
            self.normalizer = nn.BatchNorm1d(output_dim)
        elif norm == "layer":
            self.normalizer = nn.LayerNorm(output_dim)
        else:
            raise ValueError(norm)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.stack_size = stack_size
        self.activation = activation
        self.norm = norm

    def forward(self, x, cross=False) -> Tensor:
        """
        Parameters:
            x: (stack2, batch, input_dim);
                when cross=True, x may also be (stack1, stack2, batch, input_dim).
        Returns:
            if cross == True:
                (stack1, stack2, batch, output_dim)
            else (requires stack2 == stack1, the stack size of this Perceptron):
                (stack1, batch, output_dim)
        """
        if cross:
            if x.ndim == 3:
                x = torch.einsum("tbj,sij->stbi", x, self.A) + self.b[:, None, None, :]
            elif x.ndim == 4:
                x = torch.einsum("stbj,sij->stbi", x, self.A) + self.b[:, None, None, :]
            else:
                raise ValueError(x.ndim)
        else:
            x = torch.einsum("sbj,sij->sbi", x, self.A) + self.b[:, None, :]
        xshape = x.shape

        x = self.normalizer(x.reshape(np.prod(xshape[:-1]), xshape[-1])).reshape(
            *xshape
        )
        x = self.act(x)
        return x

    def jacob(self, x) -> Tensor:
        """
        Parameters:
            x: (stack, batch, input_dim)
        Returns:
            jac: (stack, batch, output_dim, input_dim)
        """
        assert self.norm is None
        a = torch.einsum("sbj,sij->sbi", x, self.A) + self.b.unsqueeze(
            1
        )  # (stack, batch, output_dim)
        ad = self.actderiv(a)  # (stack, batch, output_dim)
        jacob = ad[..., None] * self.A.unsqueeze(
            1
        )  # (stack, batch, output_dim, input_dim)
        return jacob

    def jacdet(self, x) -> Tensor:
        raise NotImplementedError
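Aside: a small sketch (illustrative, not from the commit) to make the shape conventions of `forward` and `jacob` concrete, again with a hypothetical import path:

import torch
from firelang.function import Perceptron  # hypothetical import path

p = Perceptron(input_dim=3, output_dim=5, activation="tanh", stack_size=4)

x = torch.randn(4, 10, 3)        # (stack, batch, input_dim)
print(p(x).shape)                # torch.Size([4, 10, 5])

x2 = torch.randn(7, 10, 3)       # a second stack of size 7
print(p(x2, cross=True).shape)   # torch.Size([4, 7, 10, 5]) = (stack1, stack2, batch, output_dim)

# jacob returns d act(Ax + b) / dx = act'(Ax + b)[..., None] * A, batched over stack and batch.
print(p.jacob(x).shape)          # torch.Size([4, 10, 5, 3])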