Skip to content

Commit

Permalink
firelang library v1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
kduxin committed Oct 15, 2022
1 parent ed152b4 commit 216c3ca
Show file tree
Hide file tree
Showing 21 changed files with 1,391 additions and 0 deletions.
7 changes: 7 additions & 0 deletions firelang/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os

debug_on = {"on": True, "off": False}.get(os.environ.get("FIRE_DEBUG", "off"), False)

from .measure import *
from .function import *
from .models import *
3 changes: 3 additions & 0 deletions firelang/function/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base import *
from .functional import *
from .compose import *
193 changes: 193 additions & 0 deletions firelang/function/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
from __future__ import division, annotations
from typing import Mapping, Any, Iterable, Callable, Union
import operator as op
from torch.nn import Module, ModuleList
import firelang
from firelang.stack import StackingSlicing

__all__ = ["Functional"]


class Functional(StackingSlicing):
def __init__(
self,
locals_: Mapping[str, Any],
unsliceable_params: Iterable[str] = [],
prev=[],
operator=None,
):
StackingSlicing.__init__(
self, locals_=locals_, unsliceable_params=unsliceable_params
)
prevm = [p for p in prev if isinstance(p, Module)]
self.prevm = ModuleList(prevm) if len(prevm) else prevm
self.prev = prev
self.operator = operator

def __add__(self, func_or_scalar):
if isinstance(func_or_scalar, StackingSlicing):
assert self.stack_size == func_or_scalar.stack_size
return Functional(
locals_={"stack_size": self.stack_size},
prev=[self, func_or_scalar],
operator=op.add,
)

def __sub__(self, func_or_scalar):
if isinstance(func_or_scalar, StackingSlicing):
assert self.stack_size == func_or_scalar.stack_size
return Functional(
locals_={"stack_size": self.stack_size},
prev=[self, func_or_scalar],
operator=op.sub,
)

def __mul__(self, other: Union[float, Functional, firelang.Measure]):
"""
Args:
other (Union[float, Functional, Measure]):
- if is `float` or `Functional`: generate a new Functional.
- if is `Measure`, compute the paired integral.
"""
if isinstance(other, StackingSlicing) or isinstance(other, firelang.Measure):
assert self.stack_size == other.stack_size

if isinstance(other, float) or isinstance(other, Functional):
return Functional(
locals_={"stack_size": self.stack_size},
prev=[self, other],
operator=op.mul,
)
elif isinstance(other, firelang.Measure):
return other.integral(self)
else:
raise TypeError(
f"`other` must be a float or Functional or Measure object, not {type(other)}."
)

def __truediv__(self, func_or_scalar):
if isinstance(func_or_scalar, StackingSlicing):
assert self.stack_size == func_or_scalar.stack_size
return Functional(
locals_={"stack_size": self.stack_size},
prev=[self, func_or_scalar],
operator=op.truediv,
)

def __pow__(self, pow: float):
return Functional(
locals_={"stack_size": self.stack_size},
prev=[self, pow],
operator=op.pow,
)

def __neg__(self):
return Functional(
locals_={"stack_size": self.stack_size}, prev=[self], operator=op.neg
)

def neg(self):
return Functional(
locals_={"stack_size": self.stack_size}, prev=[self], operator=op.neg
)

def __abs__(self):
return Functional(
locals_={"stack_size": self.stack_size}, prev=[self], operator=op.abs
)

def abs(self):
return Functional(
locals_={"stack_size": self.stack_size}, prev=[self], operator=op.abs
)

def apply_op(self, mapping: Callable, *other_nodes):
return Functional(
locals_={"stack_size": self.stack_size},
prev=[self, *other_nodes],
operator=mapping,
)

def forward(self, x, *args, **kwargs):
operator = self.operator
prev = self.prev
if operator is None:
return NotImplementedError
elif operator is op.add:
return prev[0].forward(x, *args, **kwargs) + prev[1].forward(
x, *args, **kwargs
)
elif operator is op.sub:
return prev[0].forward(x, *args, **kwargs) - prev[1].forward(
x, *args, **kwargs
)
elif operator is op.mul:
return prev[0].forward(x, *args, **kwargs) * prev[1].forward(
x, *args, **kwargs
)
elif operator is op.truediv:
return prev[0].forward(x, *args, **kwargs) / prev[1].forward(
x, *args, **kwargs
)
elif operator is op.pow:
return prev[0].forward(x, *args, **kwargs) ** prev[1]
elif operator is op.neg:
return -prev[0].forward(x, *args, **kwargs)
elif operator is op.abs:
return prev[0].forward(x, *args, **kwargs).abs()
elif isinstance(operator, Callable):
return operator(prev, x, *args, **kwargs)
else:
raise ValueError(f"Unrecognized operator: {operator}")

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def is_leaf(self):
return not hasattr(self, "prev") or not len(self.prev)

def __getitem__(self, idx):
if self.is_leaf():
newop = StackingSlicing.__getitem__(self, idx)
else:
sliced = [
node[idx] if isinstance(node, StackingSlicing) else node
for node in self.prev
]
newop = sliced[0].apply_op(self.operator, *sliced[1:])
return newop

def __repr__(self):
if self.is_leaf():
return Module.__repr__(self) + f", stack_size={self.stack_size}"
else:
segs = [f"{self.__class__.__name__}("]
for i, node in enumerate(self.prev):
for j, line in enumerate(repr(node).split("\n")):
if j == 0:
segs.append(" " + f"prev[{i}]: " + line)
else:
segs.append(" " + line)
op_name = (
self.operator.__name__
if hasattr(self.operator, "__name__")
else self.operator.__class__.__name__
)
segs.append(f"), operator={op_name}")
return "\n".join(segs)

def restack(self, stack_size):
if self.is_leaf():
newop = StackingSlicing.restack(self, stack_size)
else:
stacked = [
node.restack(stack_size) if hasattr(node, "restack") else node
for node in self.prev
]
newop = stacked[0].apply_op(self.operator, *stacked[1:])
return newop

stack = restack

def __matmul__(self, measure: firelang.Measure):
return measure.integral(self, cross=True)
4 changes: 4 additions & 0 deletions firelang/function/components/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .dense import *
from .planar import *
from .multilayer import *
from .common import *
19 changes: 19 additions & 0 deletions firelang/function/components/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch


def sigmoid_deriv(x):
y = torch.sigmoid(x)
return y * (1 - y)


def tanh_deriv(x):
y = torch.tanh(x)
return 1 - y**2


def identity(x):
return x


def identity_deriv(x):
return torch.ones_like(x)
103 changes: 103 additions & 0 deletions firelang/function/components/dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing_extensions import Literal
from functools import partial
import numpy as np
import torch
from torch import Tensor
from torch import nn
from torch.nn import Parameter
from .common import identity, identity_deriv, sigmoid_deriv, tanh_deriv
from ..base import Functional

__all__ = [
"Perceptron",
]


class Perceptron(Functional):
def __init__(
self,
input_dim,
output_dim,
activation="sigmoid",
norm: Literal[None, "batch", "layer"] = None,
stack_size=1,
):
Functional.__init__(self, locals())

scale = 0.1 / (input_dim + output_dim) ** 0.5
self.A = Parameter(
torch.empty(stack_size, output_dim, input_dim).normal_(0, scale)
)
self.b = Parameter(torch.zeros(stack_size, output_dim))

if activation is None:
self.act, self.actderiv = identity, identity_deriv
elif activation == "sigmoid":
self.act, self.actderiv = torch.sigmoid, sigmoid_deriv
elif activation == "tanh":
self.act, self.actderiv = torch.tanh, tanh_deriv
else:
raise ValueError(activation)

if norm is None:
self.normalizer = identity
elif norm == "batch":
self.normalizer = nn.BatchNorm1d(output_dim)
elif norm == "layer":
self.normalizer = nn.LayerNorm(output_dim)
else:
raise ValueError(norm)

self.input_dim = input_dim
self.output_dim = output_dim
self.stack_size = stack_size
self.activation = activation
self.norm = norm

def forward(self, x, cross=False) -> Tensor:
"""
Parameters:
x: (stack2, batch, input_dim)
when cross=True, also possible to get (stack1, stack2, batch, dim)
Returns:
if cross == True:
(stack1, stack2, batch, output_dim)
elif cross == False (and stack == stack2):
(stack1, batch, output_dim)
"""
if cross == True:
if x.ndim == 3:
x = torch.einsum("tbj,sij->stbi", x, self.A) + self.b[:, None, None, :]
elif x.ndim == 4:
x = torch.einsum("stbj,sij->stbi", x, self.A) + self.b[:, None, None, :]
else:
raise ValueError(x.ndim)
else:
x = torch.einsum("sbj,sij->sbi", x, self.A) + self.b[:, None, :]
xshape = x.shape

x = self.normalizer(x.reshape(np.prod(xshape[:-1]), xshape[-1])).reshape(
*xshape
)
x = self.act(x)
return x

def jacob(self, x) -> Tensor:
"""
Parameters:
x: (stack, batch, input_dim)
Returns:
jac: (stack, batch, output_dim, input_dim)
"""
assert self.norm is None
a = torch.einsum("sbj,sij->sbi", x, self.A) + self.b.unsqueeze(
1
) # (stack, batch, output_dim)
ad = self.actderiv(a) # (stack, batch, output_dim)
jacob = ad[..., None] * self.A.unsqueeze(
1
) # (stack, batch, output_dim, input_dim)
return jacob

def jacdet(self, x) -> Tensor:
raise NotImplementedError
Loading

0 comments on commit 216c3ca

Please sign in to comment.